mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] refactored the autoparallel module for organization (#1706)
* [autoparallel] refactored the autoparallel module for organization
* polish code

pull/1707/head
parent 91cd34e6e0
commit 6c331a5a09
@@ -1,12 +0,0 @@
from .sharding_strategy import ShardingStrategy, StrategiesVector
from .graph_analysis import GraphAnalyser
from .solver import Solver
from .cost_graph import CostGraph
from .strategies_constructor import StrategiesConstructor
from .constants import *
from .options import SolverOptions

__all__ = [
    'StrategiesVector', 'ShardingStrategy', 'GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph',
    'SolverOptions'
]
@@ -1,16 +1,17 @@
from .batch_norm_handler import BatchNormModuleHandler
from .conv_handler import ConvFunctionHandler, ConvModuleHandler
from .dot_handler import LinearFunctionHandler, LinearModuleHandler
from .layer_norm_handler import LayerNormModuleHandler
from .batch_norm_handler import BatchNormModuleHandler
from .conv_handler import ConvModuleHandler, ConvFunctionHandler
from .where_handler import WhereHandler
from .unary_elementwise_handler import UnaryElementwiseHandler
from .reshape_handler import ReshapeHandler
from .placeholder_handler import PlacehodlerHandler
from .output_handler import OuputHandler
from .normal_pooling_handler import NormPoolingHandler
from .output_handler import OuputHandler
from .placeholder_handler import PlacehodlerHandler
from .registry import operator_registry
from .reshape_handler import ReshapeHandler
from .unary_elementwise_handler import UnaryElementwiseHandler
from .where_handler import WhereHandler

__all__ = [
    'LinearFunctionHandler', 'LinearModuleHandler', 'LayerNormModuleHandler', 'BatchNormModuleHandler',
    'ConvModuleHandler', 'ConvFunctionHandler', 'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler',
    'OuputHandler', 'WhereHandler', 'NormPoolingHandler'
    'OuputHandler', 'WhereHandler', 'NormPoolingHandler', 'operator_registry'
]
@ -1,10 +1,11 @@
|
|||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from .node_handler import ModuleHandler, NodeHandler
|
||||
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData
|
||||
from ..strategy import BatchNormStrategyGenerator, StrategyGenerator
|
||||
from typing import List, Dict
|
||||
|
||||
from ..sharding_strategy import OperationData, OperationDataType
|
||||
from .node_handler import ModuleHandler
|
||||
from .registry import operator_registry
|
||||
from .strategy import BatchNormStrategyGenerator, StrategyGenerator
|
||||
|
||||
__all__ = ['BatchNormModuleHandler']
|
||||
|
|
@ -1,12 +1,14 @@
|
|||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from .node_handler import ModuleHandler, NodeHandler
|
||||
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData
|
||||
from ..strategy import ConvStrategyGenerator, StrategyGenerator
|
||||
from typing import List, Dict
|
||||
from .registry import operator_registry
|
||||
|
||||
__all__ = ['LinearModuleHandler', 'LinearFunctionHandler']
|
||||
from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy)
|
||||
from .node_handler import ModuleHandler, NodeHandler
|
||||
from .registry import operator_registry
|
||||
from .strategy import ConvStrategyGenerator, StrategyGenerator
|
||||
|
||||
__all__ = ['ConvModuleHandler', 'ConvFunctionHandler']
|
||||
|
||||
|
||||
@operator_registry.register(torch.nn.Conv1d)
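The @operator_registry.register(...) decorator above is how each handler declares which torch ops it can shard, so that the strategies constructor can look the right handler up for each graph node. A minimal, self-contained sketch of that registry pattern follows (illustrative only; the real implementation lives in registry.py and may differ in detail):

import torch


class DummyRegistry:
    """A stripped-down stand-in for operator_registry (assumed behaviour)."""

    def __init__(self):
        self.store = {}

    def register(self, source):
        # Used as a class decorator: maps `source` (a torch op or module type) to the handler class.
        def wrapper(handler_cls):
            self.store[source] = handler_cls
            return handler_cls

        return wrapper

    def get(self, source):
        return self.store[source]


registry = DummyRegistry()


@registry.register(torch.nn.Conv2d)
class DummyConvModuleHandler:
    ...


assert registry.get(torch.nn.Conv2d) is DummyConvModuleHandler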
@ -1,13 +1,16 @@
|
|||
from copy import deepcopy
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.utils import (switch_partition_dim, update_partition_dim)
|
||||
from colossalai.tensor.sharding_spec import ShardingException
|
||||
|
||||
from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy)
|
||||
from .node_handler import ModuleHandler, NodeHandler
|
||||
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData
|
||||
from ..strategy import LinearProjectionStrategyGenerator, StrategyGenerator, BatchedMatMulStrategyGenerator
|
||||
from typing import List, Dict, Union
|
||||
from .registry import operator_registry
|
||||
from copy import deepcopy
|
||||
from .utils import switch_partition_dim, update_partition_dim
|
||||
from .strategy import (BatchedMatMulStrategyGenerator, LinearProjectionStrategyGenerator, StrategyGenerator)
|
||||
|
||||
__all__ = ['LinearModuleHandler', 'LinearFunctionHandler', 'BMMFunctionHandler']
|
||||
|
|
@ -1,10 +1,12 @@
|
|||
import torch
|
||||
from .node_handler import NodeHandler
|
||||
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector
|
||||
from ..strategy import TensorStrategyGenerator, TensorTupleStrategyGenerator, StrategyGenerator
|
||||
from typing import List, Dict
|
||||
from .registry import operator_registry
|
||||
import operator
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
from ..sharding_strategy import OperationData, OperationDataType
|
||||
from .node_handler import NodeHandler
|
||||
from .registry import operator_registry
|
||||
from .strategy import (StrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator)
|
||||
|
||||
__all__ = ['GetItemHandler']
|
||||
|
|
@ -1,9 +1,11 @@
|
|||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
from ..sharding_strategy import OperationData, OperationDataType
|
||||
from .node_handler import ModuleHandler
|
||||
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData
|
||||
from ..strategy import LayerNormGenerator, StrategyGenerator
|
||||
from typing import List, Dict
|
||||
from .registry import operator_registry
|
||||
from .strategy import LayerNormGenerator, StrategyGenerator
|
||||
|
||||
__all__ = ['LayerNormModuleHandler']
|
||||
|
|
@ -1,11 +1,14 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Union
|
||||
|
||||
from torch.fx.node import Node
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, ShardingStrategy, StrategiesVector,
|
||||
TrainCycleItem)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from typing import Dict, List, Union
|
||||
from ..sharding_strategy import ShardingStrategy, StrategiesVector, OperationData, TrainCycleItem
|
||||
from ..strategy import StrategyGenerator
|
||||
from .._utils import generate_resharding_costs
|
||||
|
||||
from .strategy import StrategyGenerator
|
||||
|
||||
|
||||
class NodeHandler(ABC):
|
|
@ -1,10 +1,11 @@
|
|||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from .node_handler import ModuleHandler, NodeHandler
|
||||
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData
|
||||
from ..strategy import NormalPoolStrategyGenerator, StrategyGenerator
|
||||
from typing import List, Dict
|
||||
|
||||
from ..sharding_strategy import OperationData, OperationDataType
|
||||
from .node_handler import ModuleHandler
|
||||
from .registry import operator_registry
|
||||
from .strategy import NormalPoolStrategyGenerator, StrategyGenerator
|
||||
|
||||
__all__ = ['NormPoolingHandler']
|
||||
|
|
@ -1,10 +1,10 @@
|
|||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
from ..sharding_strategy import OperationData, OperationDataType
|
||||
from .node_handler import NodeHandler
|
||||
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector
|
||||
from colossalai.auto_parallel.solver.strategy import StrategyGenerator
|
||||
from colossalai.auto_parallel.solver.strategy.output_generator import OutputGenerator
|
||||
from typing import List, Dict
|
||||
from .registry import operator_registry
|
||||
from .strategy import OutputGenerator, StrategyGenerator
|
||||
|
||||
__all__ = ['OuputHandler']
|
||||
|
|
@ -1,10 +1,8 @@
|
|||
import torch
|
||||
from typing import Dict, List
|
||||
|
||||
from ..sharding_strategy import OperationData, OperationDataType
|
||||
from .node_handler import NodeHandler
|
||||
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData
|
||||
from colossalai.auto_parallel.solver.strategy import StrategyGenerator
|
||||
from colossalai.auto_parallel.solver.strategy.placeholder_generator import PlaceholderGenerator
|
||||
from typing import List, Dict
|
||||
from .registry import operator_registry
|
||||
from .strategy import PlaceholderGenerator, StrategyGenerator
|
||||
|
||||
__all__ = ['PlacehodlerHandler']
|
||||
|
|
@ -1,10 +1,11 @@
|
|||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
from ..sharding_strategy import OperationData, OperationDataType
|
||||
from .node_handler import NodeHandler
|
||||
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector
|
||||
from ..strategy import ReshapeGenerator, StrategyGenerator
|
||||
from typing import List, Dict
|
||||
from .registry import operator_registry
|
||||
import operator
|
||||
from .strategy import ReshapeGenerator, StrategyGenerator
|
||||
|
||||
__all__ = ['ReshapeHandler']
|
||||
|
|
@ -1,15 +1,16 @@
|
|||
from .strategy_generator import StrategyGenerator
|
||||
from .matmul_strategy_generator import DotProductStrategyGenerator, MatVecStrategyGenerator, LinearProjectionStrategyGenerator, BatchedMatMulStrategyGenerator
|
||||
from .conv_strategy_generator import ConvStrategyGenerator
|
||||
from .batch_norm_generator import BatchNormStrategyGenerator
|
||||
from .unary_elementwise_generator import UnaryElementwiseGenerator
|
||||
from .getitem_generator import GetItemStrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator
|
||||
from .conv_strategy_generator import ConvStrategyGenerator
|
||||
from .getitem_generator import (GetItemStrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator)
|
||||
from .layer_norm_generator import LayerNormGenerator
|
||||
from .where_generator import WhereGenerator
|
||||
from .reshape_generator import ReshapeGenerator
|
||||
from .matmul_strategy_generator import (BatchedMatMulStrategyGenerator, DotProductStrategyGenerator,
|
||||
LinearProjectionStrategyGenerator, MatVecStrategyGenerator)
|
||||
from .normal_pooling_generator import NormalPoolStrategyGenerator
|
||||
from .placeholder_generator import PlaceholderGenerator
|
||||
from .output_generator import OutputGenerator
|
||||
from .placeholder_generator import PlaceholderGenerator
|
||||
from .reshape_generator import ReshapeGenerator
|
||||
from .strategy_generator import StrategyGenerator
|
||||
from .unary_elementwise_generator import UnaryElementwiseGenerator
|
||||
from .where_generator import WhereGenerator
|
||||
|
||||
__all__ = [
|
||||
'StrategyGenerator', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator', 'LinearProjectionStrategyGenerator',
|
|
@ -1,11 +1,11 @@
|
|||
import copy
|
||||
import operator
|
||||
from functools import reduce
|
||||
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
|
||||
from colossalai.tensor.shape_consistency import CollectiveCommPattern
|
||||
|
||||
from .strategy_generator import StrategyGenerator
|
||||
from typing import List
|
||||
from .._utils import exception_handler
|
||||
import copy
|
||||
|
||||
__all__ = ['BatchNormStrategyGenerator']
|
||||
|
|
@ -1,12 +1,14 @@
|
|||
import operator
|
||||
from functools import reduce
|
||||
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
|
||||
from colossalai.tensor.shape_consistency import CollectiveCommPattern
|
||||
from .strategy_generator import StrategyGenerator
|
||||
from typing import List
|
||||
from .._utils import exception_handler
|
||||
import warnings
|
||||
import copy
|
||||
import operator
|
||||
import warnings
|
||||
from functools import reduce
|
||||
from typing import List
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
|
||||
from colossalai.auto_parallel.tensor_shard.utils import exception_handler
|
||||
from colossalai.tensor.shape_consistency import CollectiveCommPattern
|
||||
|
||||
from .strategy_generator import StrategyGenerator
|
||||
|
||||
|
||||
class ConvStrategyGenerator(StrategyGenerator):
|
|
@ -1,12 +1,10 @@
|
|||
import operator
|
||||
from functools import reduce
|
||||
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
|
||||
from colossalai.tensor.shape_consistency import CollectiveCommPattern
|
||||
from .strategy_generator import FollowingStrategyGenerator
|
||||
from typing import List
|
||||
from .._utils import exception_handler
|
||||
import copy
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
|
||||
from colossalai.tensor.shape_consistency import CollectiveCommPattern
|
||||
|
||||
from .strategy_generator import FollowingStrategyGenerator
|
||||
|
||||
__all__ = ['GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator']
|
||||
|
||||
|
|
@ -1,11 +1,13 @@
|
|||
import copy
|
||||
import operator
|
||||
from functools import reduce
|
||||
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
|
||||
from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
|
||||
enumerate_all_possible_2d_sharding)
|
||||
from colossalai.tensor.shape_consistency import CollectiveCommPattern
|
||||
|
||||
from .strategy_generator import StrategyGenerator
|
||||
from typing import List
|
||||
from .._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
|
||||
import copy
|
||||
|
||||
__all__ = ['LayerNormGenerator']
|
||||
|
|
@ -1,11 +1,12 @@
|
|||
from audioop import bias
|
||||
import operator
|
||||
from functools import reduce
|
||||
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
|
||||
from colossalai.tensor.shape_consistency import CollectiveCommPattern
|
||||
from .strategy_generator import StrategyGenerator
|
||||
from typing import List
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
|
||||
from colossalai.tensor.shape_consistency import CollectiveCommPattern
|
||||
|
||||
from .strategy_generator import StrategyGenerator
|
||||
|
||||
|
||||
class MatMulStrategyGenerator(StrategyGenerator):
|
||||
"""
|
|
@ -1,11 +1,13 @@
|
|||
import copy
|
||||
import operator
|
||||
from functools import reduce
|
||||
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
|
||||
from colossalai.tensor.shape_consistency import CollectiveCommPattern
|
||||
from .strategy_generator import StrategyGenerator
|
||||
from typing import List
|
||||
from .._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
|
||||
import copy
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
|
||||
from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
|
||||
enumerate_all_possible_2d_sharding)
|
||||
|
||||
from .strategy_generator import StrategyGenerator
|
||||
|
||||
|
||||
class NormalPoolStrategyGenerator(StrategyGenerator):
|
|
@ -1,11 +1,6 @@
|
|||
import operator
|
||||
from functools import reduce
|
||||
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
|
||||
from colossalai.tensor.shape_consistency import CollectiveCommPattern
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
|
||||
|
||||
from .strategy_generator import OutputStrategyGenerator
|
||||
from typing import List
|
||||
from .._utils import exception_handler
|
||||
import copy
|
||||
|
||||
__all__ = ['OutputGenerator']
|
||||
|
||||
|
@ -46,7 +41,7 @@ class OutputGenerator(OutputStrategyGenerator):
|
|||
communication_action_mapping = {}
|
||||
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
|
||||
|
||||
name = f'Replica Output'
|
||||
name = 'Replica Output'
|
||||
|
||||
strategy = self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
|
@ -1,11 +1,6 @@
|
|||
import operator
|
||||
from functools import reduce
|
||||
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
|
||||
from colossalai.tensor.shape_consistency import CollectiveCommPattern
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
|
||||
|
||||
from .strategy_generator import StrategyGenerator
|
||||
from typing import List
|
||||
from .._utils import exception_handler
|
||||
import copy
|
||||
|
||||
__all__ = ['PlaceholderGenerator']
|
||||
|
||||
|
@ -47,7 +42,7 @@ class PlaceholderGenerator(StrategyGenerator):
|
|||
communication_action_mapping = {}
|
||||
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
|
||||
|
||||
name = f'Replica Placeholder'
|
||||
name = 'Replica Placeholder'
|
||||
|
||||
strategy = self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
|
@ -1,11 +1,10 @@
|
|||
import operator
|
||||
from functools import reduce
|
||||
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
|
||||
from colossalai.tensor.shape_consistency import CollectiveCommPattern
|
||||
from .strategy_generator import FollowingStrategyGenerator
|
||||
from typing import List
|
||||
import copy
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
|
||||
from colossalai.tensor.shape_consistency import CollectiveCommPattern
|
||||
|
||||
from .strategy_generator import FollowingStrategyGenerator
|
||||
|
||||
__all__ = ['ReshapeGenerator']
|
||||
|
||||
|
|
@ -1,15 +1,16 @@
|
|||
import operator
|
||||
import torch
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from functools import reduce
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import reduce
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
import torch
|
||||
from torch.fx import Node
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, ShardingStrategy,
|
||||
TrainCycleItem)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from typing import Dict, List, Union, Any
|
||||
from ..sharding_strategy import OperationData, ShardingStrategy, TrainCycleItem, OperationDataType
|
||||
from torch.fx import Node
|
||||
import copy
|
||||
|
||||
|
||||
class StrategyGenerator(ABC):
|
|
@ -1,12 +1,9 @@
|
|||
import operator
|
||||
from functools import reduce
|
||||
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
|
||||
from colossalai.tensor.shape_consistency import CollectiveCommPattern
|
||||
from .strategy_generator import FollowingStrategyGenerator
|
||||
from typing import List
|
||||
from .._utils import exception_handler
|
||||
import copy
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
|
||||
|
||||
from .strategy_generator import FollowingStrategyGenerator
|
||||
|
||||
__all__ = ['UnaryElementwiseGenerator']
|
||||
|
||||
|
|
@ -1,12 +1,11 @@
|
|||
import operator
|
||||
from functools import reduce
|
||||
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
|
||||
from colossalai.tensor.shape_consistency import CollectiveCommPattern
|
||||
from .strategy_generator import StrategyGenerator, FollowingStrategyGenerator
|
||||
from typing import List
|
||||
from .._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
|
||||
import copy
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
|
||||
from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
|
||||
enumerate_all_possible_2d_sharding)
|
||||
|
||||
from .strategy_generator import StrategyGenerator
|
||||
|
||||
__all__ = ['WhereGenerator']
|
||||
|
||||
|
|
@ -1,10 +1,11 @@
|
|||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
from ..sharding_strategy import OperationData, OperationDataType
|
||||
from .node_handler import NodeHandler
|
||||
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector
|
||||
from ..strategy import UnaryElementwiseGenerator, StrategyGenerator
|
||||
from typing import List, Dict
|
||||
from .registry import operator_registry
|
||||
import operator
|
||||
from .strategy import StrategyGenerator, UnaryElementwiseGenerator
|
||||
|
||||
__all__ = ['UnaryElementwiseHandler']
|
||||
|
|
@ -1,12 +1,14 @@
|
|||
import torch
|
||||
from .node_handler import NodeHandler
|
||||
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector
|
||||
from ..strategy import WhereGenerator, StrategyGenerator
|
||||
from .broadcast import recover_sharding_spec_for_broadcast_shape
|
||||
from typing import List, Dict
|
||||
from .registry import operator_registry
|
||||
import operator
|
||||
import copy
|
||||
import operator
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy, StrategiesVector)
|
||||
from ..utils import recover_sharding_spec_for_broadcast_shape
|
||||
from .node_handler import NodeHandler
|
||||
from .registry import operator_registry
|
||||
from .strategy import StrategyGenerator, WhereGenerator
|
||||
|
||||
__all__ = ['WhereHandler']
|
||||
|
|
@ -1,17 +1,14 @@
|
|||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
import operator
|
||||
import torch
|
||||
from functools import reduce
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
import torch
|
||||
from colossalai.tensor.shape_consistency import CommSpec
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
|
||||
from typing import Dict, List, Union, Tuple, Any
|
||||
from torch.fx.node import Node
|
||||
from .constants import *
|
||||
|
||||
from .constants import (BCAST_FUNC_OP, ELEMENTWISE_FUNC_OP, ELEMENTWISE_MODULE_OP, RESHAPE_FUNC_OP)
|
||||
|
||||
__all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector']
|
||||
|
||||
|
@@ -75,6 +72,11 @@ class TrainCycleItem:
@dataclass
class MemoryCost:
    """
    MemoryCost is a dataclass which stores the memory usage in the program.

    Args:
        activation (int): the memory cost incurred by the activations in bytes.
        parameter (int): the memory cost incurred by the module parameter in bytes.
    """
    activation: int = 0
    parameter: int = 0
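To illustrate the dataclass above, a short hedged sketch of how a strategy generator could report the forward-pass memory of a Linear layer; the byte counts are illustrative assumptions, not values produced by this commit:

from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost

# Illustrative numbers for a float32 nn.Linear(1024, 1024) with batch size 32.
fwd_memory = MemoryCost(
    activation=32 * 1024 * 4,              # output activations: batch * out_features * 4 bytes
    parameter=(1024 * 1024 + 1024) * 4)    # weight + bias, in bytes
print(fwd_memory.activation, fwd_memory.parameter)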
@@ -0,0 +1,7 @@
from .cost_graph import CostGraph
from .graph_analysis import GraphAnalyser
from .options import SolverOptions
from .solver import Solver
from .strategies_constructor import StrategiesConstructor

__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph', 'SolverOptions']
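For orientation, a hedged sketch of how these newly grouped components fit together, pieced together from the constructor signatures and test imports that appear later in this diff; the tracing call and the method names not shown in this commit are assumptions:

import torch
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.solver import (CostGraph, GraphAnalyser, Solver, SolverOptions,
                                                          StrategiesConstructor)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU())
graph = ColoTracer().trace(model, meta_args={'input': torch.rand(8, 64).to('meta')})    # assumed tracer API
gm = ColoGraphModule(model, graph)

device_mesh = DeviceMesh(torch.arange(4), mesh_shape=(2, 2))                            # assumed 2x2 mesh setup

strategies_constructor = StrategiesConstructor(graph, device_mesh, SolverOptions())
strategies_constructor.build_strategies_and_cost()                                      # assumed builder method
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
graph_analyser = GraphAnalyser(gm)

solver = Solver(graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1.0)
s_val, e_val, objective, status = solver.call_solver_serialized_args()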
@ -1,7 +1,4 @@
|
|||
from typing import List
|
||||
import math
|
||||
from torch.fx.node import Node
|
||||
from colossalai.auto_parallel.solver.constants import INFINITY_COST
|
||||
from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST
|
||||
|
||||
|
||||
class CostGraph:
@ -1,9 +1,10 @@
|
|||
from dataclasses import dataclass
|
||||
from torch.fx.node import Node
|
||||
from typing import List
|
||||
|
||||
from torch.fx.graph import Graph
|
||||
from torch.fx.graph_module import GraphModule
|
||||
from collections import OrderedDict as ODict
|
||||
from typing import List, OrderedDict, Union, Any
|
||||
from torch.fx.node import Node
|
||||
|
||||
from colossalai.fx.passes.utils import get_node_module
|
||||
|
||||
__all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser']
@ -1,18 +1,21 @@
|
|||
import warnings
|
||||
|
||||
import time
|
||||
import numpy as np
|
||||
import multiprocessing
|
||||
from torch.fx.node import Node
|
||||
from torch.fx.graph import Graph
|
||||
from . import GraphAnalyser
|
||||
from colossalai.auto_parallel.solver.cost_graph import CostGraph
|
||||
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
|
||||
import time
|
||||
import warnings
|
||||
from typing import Dict
|
||||
from .constants import INFINITY_COST
|
||||
|
||||
import numpy as np
|
||||
from torch.fx.graph import Graph
|
||||
from torch.fx.node import Node
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST
|
||||
|
||||
from .cost_graph import CostGraph
|
||||
from .graph_analysis import GraphAnalyser
|
||||
from .strategies_constructor import StrategiesConstructor
|
||||
|
||||
try:
|
||||
import pulp
|
||||
from pulp import LpVariable, LpProblem, LpMinimize, lpSum, lpDot, LpStatus
|
||||
from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum
|
||||
except:
|
||||
warnings.warn(f'please install the pulp')
|
||||
|
||||
|
@ -21,454 +24,6 @@ __all___ = ['Solver']
|
|||
|
||||
class Solver:
|
||||
|
||||
def __init__(self,
|
||||
graph: Graph,
|
||||
strategies_constructor: StrategiesConstructor,
|
||||
cost_graph: CostGraph,
|
||||
graph_analyser: GraphAnalyser,
|
||||
memory_budget: float = -1.0,
|
||||
solution_numbers: int = 1,
|
||||
memory_increasing_coefficient: float = 1.3):
|
||||
'''
|
||||
Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph.
|
||||
|
||||
Argument:
|
||||
graph: The computing graph to be optimized.
|
||||
strategies_constructor: It will provide all the possible strategies for each node in the computing graph.
|
||||
cost_graph: A graph data structure to simplify the edge cost graph.
|
||||
graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints.
|
||||
memory_budget: Memory constraint for the solution.
|
||||
solution_numbers: If solution_numbers is larger than one, the solver will produce a series of solutions based on different memory budgets.
|
||||
memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budget.
|
||||
'''
|
||||
self.graph = graph
|
||||
self.strategies_constructor = strategies_constructor
|
||||
self.cost_graph = cost_graph
|
||||
self.graph_analyser = graph_analyser
|
||||
self.leaf_strategies = self.strategies_constructor.leaf_strategies
|
||||
self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
|
||||
self.strategy_map = self.strategies_constructor.strategy_map
|
||||
self.memory_budget = memory_budget
|
||||
self.solution_numbers = solution_numbers
|
||||
if self.solution_numbers > 1:
|
||||
self.memory_increasing_coefficient = memory_increasing_coefficient
|
||||
else:
|
||||
self.memory_increasing_coefficient = 1
|
||||
self.liveness_list = self.graph_analyser.liveness_analysis()
|
||||
self.node_index_dict = self._generate_node_index_dict()
|
||||
# The last solution vector of auto sharding.
|
||||
self.last_s_val = None
|
||||
# The last objective value of the best ILP solution.
|
||||
self.last_objective = None
|
||||
|
||||
def _recover_merged_node_strategy(self):
|
||||
'''
|
||||
During cost graph construction, some nodes, such as unary element-wise nodes or ReshapeOp, were merged into the previous node.
|
||||
Therefore, the index of those strategies is copied from the previous node. This method is used to recover the strategy index of those merged
|
||||
nodes.
|
||||
'''
|
||||
for node_index, node in enumerate(self.nodes):
|
||||
if node.strategies_vector.check_merge():
|
||||
# the merged node has only one input, and its strategies follow the input sharding strategy
|
||||
input_strategies_vector = node.args[0].strategies_vector
|
||||
input_best_strategy_index = self.last_s_val[node_index - 1]
|
||||
input_sharding_spec = input_strategies_vector[input_best_strategy_index].output_sharding_spec
|
||||
for strategy_index, strategy in enumerate(node.strategies_vector):
|
||||
if strategy.input_shardings[0].sharding_sequence == input_sharding_spec.sharding_sequence:
|
||||
self.last_s_val[node_index] = strategy_index
|
||||
break
|
||||
|
||||
def _generate_node_index_dict(self) -> Dict[Node, int]:
|
||||
node_index_dict = {}
|
||||
for index, strategies_vector in enumerate(self.leaf_strategies):
|
||||
node_index_dict[strategies_vector.node] = index
|
||||
return node_index_dict
|
||||
|
||||
def _prepare_data_for_solver(self):
|
||||
'''
|
||||
Extract information from components for solver.
|
||||
'''
|
||||
node_nums = len(self.leaf_strategies)
|
||||
memory_budget = self.memory_budget
|
||||
|
||||
# prepare strategies_len
|
||||
strategies_len = []
|
||||
for node in self.nodes:
|
||||
strategies_len.append(self.cost_graph.node_lens[node])
|
||||
strategies_len = np.array(strategies_len)
|
||||
|
||||
# prepare following_nodes
|
||||
following_nodes = self.cost_graph.following_dict
|
||||
index_following_nodes = {}
|
||||
for src, target in following_nodes.items():
|
||||
src_index = self.node_index_dict[src]
|
||||
target_index = self.node_index_dict[target]
|
||||
index_following_nodes[src_index] = target_index
|
||||
following_nodes = index_following_nodes
|
||||
for index in range(node_nums):
|
||||
if index not in following_nodes:
|
||||
following_nodes[index] = -1
|
||||
|
||||
# prepare edge_pairs and resharding costs
|
||||
edge_pairs = []
|
||||
resharding_costs = []
|
||||
for pairs, edge_cost in self.cost_graph.edge_costs.items():
|
||||
src_node = pairs[0]
|
||||
dst_node = pairs[1]
|
||||
src_node_index = self.node_index_dict[src_node]
|
||||
dst_node_index = self.node_index_dict[dst_node]
|
||||
edge_pairs.append(src_node_index)
|
||||
edge_pairs.append(dst_node_index)
|
||||
|
||||
for i in range(strategies_len[src_node_index]):
|
||||
for j in range(strategies_len[dst_node_index]):
|
||||
resharding_costs.append(edge_cost[(i, j)])
|
||||
edge_pairs = np.array(edge_pairs)
|
||||
resharding_costs = np.array(resharding_costs)
|
||||
|
||||
# prepare liveness_set
|
||||
liveness_set = self.liveness_list
|
||||
|
||||
# omit alias_set now
|
||||
alias_set = None
|
||||
alias_convert_costs = None
|
||||
|
||||
# prepare compute_costs, communication_costs and memory_costs
|
||||
compute_costs = []
|
||||
communication_costs = []
|
||||
memory_costs = []
|
||||
extra_node_costs = self.cost_graph.extra_node_costs
|
||||
for strategies_vector in self.leaf_strategies:
|
||||
node = strategies_vector.node
|
||||
for index, strategy in enumerate(strategies_vector):
|
||||
compute_costs.append(strategy.compute_cost)
|
||||
# node in extra_node_costs means it has some extra communication
|
||||
# cost from node merging, so we need to add those extra communication
|
||||
# cost into the communication cost of this strategy
|
||||
if node in extra_node_costs:
|
||||
origin_communication_cost = strategy.communication_cost
|
||||
extra_node_cost = extra_node_costs[node][index]
|
||||
communication_cost = origin_communication_cost + extra_node_cost
|
||||
communication_costs.append(communication_cost)
|
||||
else:
|
||||
communication_costs.append(strategy.communication_cost)
|
||||
# temporarily we just consider the forward memory cost
|
||||
memory_cost = strategy.memory_cost
|
||||
if isinstance(memory_cost, tuple):
|
||||
memory_costs.append(memory_cost[0])
|
||||
else:
|
||||
memory_costs.append(memory_cost)
|
||||
compute_costs = np.array(compute_costs)
|
||||
communication_costs = np.array(communication_costs)
|
||||
memory_costs = np.array(memory_costs)
|
||||
|
||||
# omit initial value for nodes
|
||||
s_init_np = None
|
||||
|
||||
return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np
|
||||
|
||||
def _call_solver_serialized_args(self,
|
||||
node_nums,
|
||||
memory_budget,
|
||||
strategies_len,
|
||||
following_nodes,
|
||||
edge_pairs,
|
||||
alias_set,
|
||||
liveness_set,
|
||||
compute_costs,
|
||||
communication_costs,
|
||||
memory_costs,
|
||||
resharding_costs,
|
||||
alias_convert_costs,
|
||||
s_init_np=None):
|
||||
"""
|
||||
Call the solver with serialized arguments.
|
||||
"""
|
||||
|
||||
tic = time.time()
|
||||
|
||||
for x in [strategies_len, edge_pairs, compute_costs, communication_costs, memory_costs, resharding_costs]:
|
||||
assert isinstance(x, np.ndarray)
|
||||
assert len(strategies_len) == node_nums, "strategies_len"
|
||||
|
||||
def get_non_zero_index(binary_vector):
|
||||
"""
|
||||
Get the index of non-zero item in a vector.
|
||||
"""
|
||||
ct = 0
|
||||
ret = None
|
||||
for i, elem in enumerate(binary_vector):
|
||||
if pulp.value(elem):
|
||||
ret = i
|
||||
ct += 1
|
||||
|
||||
assert ct == 1
|
||||
return ret
|
||||
|
||||
# 0. Unpack flatten numpy arrays
|
||||
s_follow = following_nodes
|
||||
|
||||
E = edge_pairs.reshape((-1, 2)) # noqa
|
||||
r = []
|
||||
pt = 0
|
||||
edge_set = set()
|
||||
for (i, j) in E:
|
||||
prod_length = strategies_len[i] * strategies_len[j]
|
||||
|
||||
if (i, j) in edge_set:
|
||||
raise ValueError(f"Duplicated edges: {(i, j)}")
|
||||
|
||||
edge_set.add((i, j))
|
||||
r.append(resharding_costs[pt:pt + prod_length])
|
||||
pt += prod_length
|
||||
assert pt == len(resharding_costs)
|
||||
|
||||
######################
|
||||
# omit alias set now #
|
||||
######################
|
||||
|
||||
# A = alias_set.reshape((-1, 2)) # noqa
|
||||
# for (i, j) in A:
|
||||
# prod_length = strategies_len[i] * strategies_len[j]
|
||||
# v.append(alias_convert_costs[pt:pt + prod_length])
|
||||
# pt += prod_length
|
||||
# assert pt == len(alias_convert_costs)
|
||||
|
||||
# L = [] # noqa
|
||||
# pt = node_nums
|
||||
# for i in range(node_nums):
|
||||
# length = liveness_set[i]
|
||||
# L.append(liveness_set[pt:pt + length])
|
||||
# pt += length
|
||||
# assert pt == len(liveness_set)
|
||||
v = []
|
||||
pt = 0
|
||||
|
||||
c = []
|
||||
d = []
|
||||
m = []
|
||||
pt = 0
|
||||
for i in range(node_nums):
|
||||
length = strategies_len[i]
|
||||
c.append(compute_costs[pt:pt + length])
|
||||
d.append(communication_costs[pt:pt + length])
|
||||
m.append(memory_costs[pt:pt + length])
|
||||
pt += length
|
||||
assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}"
|
||||
assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}"
|
||||
assert pt == len(memory_costs), f"{pt} == {len(memory_costs)}"
|
||||
|
||||
# 1. Create variables
|
||||
|
||||
#############################
|
||||
# create variables for node #
|
||||
#############################
|
||||
s = []
|
||||
num_nodes = 0
|
||||
reverse_follow_backpatch = []
|
||||
for i in range(node_nums):
|
||||
if s_follow[i] < 0:
|
||||
if strategies_len[i] == 1:
|
||||
s.append([1])
|
||||
else:
|
||||
num_nodes += 1
|
||||
s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
|
||||
else:
|
||||
if s_follow[i] < len(s):
|
||||
s.append(s[s_follow[i]])
|
||||
else:
|
||||
s.append(None)
|
||||
reverse_follow_backpatch.append(i)
|
||||
|
||||
for i in reverse_follow_backpatch:
|
||||
s[i] = s[s_follow[i]]
|
||||
|
||||
#############################
|
||||
# create variables for edge #
|
||||
#############################
|
||||
e = []
|
||||
num_edges = 0
|
||||
for (idx, (i, j)) in enumerate(E):
|
||||
if len(s[i]) == 1:
|
||||
e.append(s[j])
|
||||
elif len(s[j]) == 1:
|
||||
e.append(s[i])
|
||||
else:
|
||||
num_edges += 1
|
||||
e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
|
||||
assert len(e[idx]) == len(r[idx])
|
||||
for element in s:
|
||||
assert len(element) > 0
|
||||
# 2. Set initial value
|
||||
######################################
|
||||
# set a initial value for warm start #
|
||||
######################################
|
||||
if s_init_np is not None:
|
||||
s_init = s_init_np.reshape((-1, 3))
|
||||
for (idx, value, fix) in s_init:
|
||||
for i in range(len(s[idx])):
|
||||
s[idx][i].setInitialValue(i == value)
|
||||
if fix:
|
||||
s[idx][i].fixValue()
|
||||
|
||||
# 3. Objective
|
||||
prob = LpProblem("myProblem", LpMinimize)
|
||||
###################################################################
|
||||
# computing the node cost(computing cost and communication cost) #
|
||||
###################################################################
|
||||
obj = 0
|
||||
for i in range(node_nums):
|
||||
assert len(s[i]) == len(c[i])
|
||||
assert len(s[i]) == len(d[i])
|
||||
|
||||
obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i])
|
||||
|
||||
#############################################
|
||||
# computing the edge cost(resharding cost) #
|
||||
#############################################
|
||||
for i in range(len(E)):
|
||||
assert len(e[i]) == len(r[i])
|
||||
obj += lpDot(e[i], r[i])
|
||||
|
||||
prob += obj
|
||||
|
||||
# 4. Constraints
|
||||
# (a). specified by `cat="Binary"`
|
||||
|
||||
# (b)
|
||||
#################################################
|
||||
# make sure each node only choose one strategy #
|
||||
#################################################
|
||||
for i in range(node_nums):
|
||||
if s_follow[i] < 0:
|
||||
prob += lpSum(s[i]) == 1
|
||||
|
||||
# (c)
|
||||
#################################################
|
||||
# compute memory consumption with liveness set #
|
||||
#################################################
|
||||
if memory_budget > 0:
|
||||
for liveness_stage in liveness_set:
|
||||
mem = 0
|
||||
for live_variable in liveness_stage.unique_live_vars:
|
||||
node_index = self.node_index_dict[live_variable.node]
|
||||
mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
|
||||
prob += mem <= memory_budget
|
||||
|
||||
# (d). specified by `cat="Binary"`
|
||||
|
||||
for (idx, (i, j)) in enumerate(E):
|
||||
if strategies_len[i] == 1 or strategies_len[j] == 1:
|
||||
continue
|
||||
|
||||
# (e)
|
||||
prob += lpSum(e[idx]) == 1
|
||||
|
||||
# (f)
|
||||
for row in range(len(s[i])):
|
||||
C = len(s[j]) # noqa
|
||||
prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row]
|
||||
|
||||
# (g)
|
||||
for col in range(len(s[j])):
|
||||
R = len(s[i]) # noqa
|
||||
C = len(s[j]) # noqa
|
||||
prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col]
|
||||
|
||||
# (h)
|
||||
######################
|
||||
# omit alias set now #
|
||||
######################
|
||||
|
||||
# alias_set = set()
|
||||
# for (idx, (i, j)) in enumerate(A):
|
||||
# R = len(s[i]) # noqa
|
||||
# C = len(s[j]) # noqa
|
||||
# if (i, j) in alias_set:
|
||||
# raise ValueError(f"Duplicated edges: {(i, j)}")
|
||||
|
||||
# alias_set.add((i, j))
|
||||
# alias_set.add((j, i))
|
||||
|
||||
# for row in range(len(s[i])):
|
||||
# for col in range(len(s[j])):
|
||||
# if v[idx][row * C + col] > 0.5:
|
||||
# prob += s[i][row] + s[j][col] <= 1
|
||||
|
||||
verbose = True
|
||||
|
||||
msg = verbose
|
||||
time_limit = 600
|
||||
assert "COIN_CMD" in pulp.listSolvers(
|
||||
onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'")
|
||||
|
||||
solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count())
|
||||
# solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit)
|
||||
prob.solve(solver)
|
||||
|
||||
status = prob.status
|
||||
objective = pulp.value(prob.objective)
|
||||
objective = float(objective) if objective is not None else -1.0
|
||||
if verbose:
|
||||
print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t"
|
||||
f"Time: {time.time() - tic}")
|
||||
print(f"#nodes: {num_nodes}, #edges: {num_edges}")
|
||||
|
||||
if prob.status in [pulp.LpStatusInfeasible]:
|
||||
raise RuntimeError("Cannot run the function under the given memory budget. "
|
||||
"Please increase the memory budget.")
|
||||
|
||||
# Get and check results
|
||||
s_val = np.full((node_nums,), -1, dtype=np.int32)
|
||||
for i in range(node_nums):
|
||||
s_val[i] = get_non_zero_index(s[i])
|
||||
|
||||
e_val = np.full((len(E),), -1, dtype=np.int32)
|
||||
for (idx, (i, j)) in enumerate(E):
|
||||
e_val[idx] = get_non_zero_index(e[idx])
|
||||
i_spec_index = e_val[idx] // len(s[j])
|
||||
j_spec_index = e_val[idx] % len(s[j])
|
||||
assert i_spec_index == s_val[i], f"e_val[{i}][{j}]"
|
||||
assert j_spec_index == s_val[j], f"e_val[{i}][{j}]"
|
||||
if verbose and r[idx][e_val[idx]] > 0:
|
||||
print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}")
|
||||
|
||||
self.last_s_val = list(s_val)
|
||||
self._recover_merged_node_strategy()
|
||||
self.last_objective = objective
|
||||
|
||||
if objective > INFINITY_COST:
|
||||
warnings.warn("Detect unexpected behaviors in the auto-sharding pass.")
|
||||
|
||||
return self.last_s_val, e_val, self.last_objective, status
|
||||
|
||||
def call_solver_serialized_args(self):
|
||||
"""
|
||||
Call the solver with serialized arguments and handle python errors. Additionally,
|
||||
we could give a series of solutions with different memory budgets.
|
||||
"""
|
||||
if self.solution_numbers == 1:
|
||||
args = self._prepare_data_for_solver()
|
||||
ret = self._call_solver_serialized_args(*args)
|
||||
|
||||
return ret
|
||||
|
||||
origin_memory_budget = self.memory_budget
|
||||
memory_budget_list = [
|
||||
origin_memory_budget * self.memory_increasing_coefficient**i for i in range(self.solution_numbers)
|
||||
]
|
||||
ret_list = []
|
||||
for memory_budget in memory_budget_list:
|
||||
self.memory_budget = memory_budget
|
||||
args = self._prepare_data_for_solver()
|
||||
ret = self._call_solver_serialized_args(*args)
|
||||
ret_list.append(ret)
|
||||
|
||||
return ret_list
|
||||
|
||||
|
||||
class Solver_V2:
|
||||
|
||||
def __init__(self,
|
||||
graph: Graph,
|
||||
strategies_constructor: StrategiesConstructor,
|
||||
|
@ -480,7 +35,6 @@ class Solver_V2:
|
|||
memory_increasing_coefficient: float = 1.3):
|
||||
'''
|
||||
Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph.
|
||||
|
||||
Argument:
|
||||
graph: The computing graph to be optimized.
|
||||
strategies_constructor: It will provide all the possible strategies for each node in the computing graph.
|
|
@ -1,22 +1,19 @@
|
|||
import math
|
||||
import operator
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from torch.fx import Graph, Node
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler import (OuputHandler, PlacehodlerHandler, operator_registry)
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (ShardingStrategy, StrategiesVector)
|
||||
from colossalai.auto_parallel.tensor_shard.utils import (generate_resharding_costs, generate_sharding_spec)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.auto_parallel.solver.node_handler.registry import operator_registry
|
||||
from colossalai.auto_parallel.solver.node_handler.placeholder_handler import PlacehodlerHandler
|
||||
from colossalai.auto_parallel.solver.node_handler.output_handler import OuputHandler
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
from .options import SolverOptions
|
||||
from . import ShardingStrategy, StrategiesVector
|
||||
from .node_handler import *
|
||||
from .constants import *
|
||||
from copy import deepcopy
|
||||
import math
|
||||
import torch
|
||||
import operator
|
||||
from typing import Dict, List
|
||||
from ._utils import generate_sharding_spec, generate_resharding_costs
|
||||
import builtins
|
||||
|
||||
__all__ = ['StrategiesConstructor']
|
||||
|
|
@@ -0,0 +1,12 @@
from .broadcast import (BroadcastType, get_broadcast_shape, is_broadcastable, recover_sharding_spec_for_broadcast_shape)
from .factory import generate_resharding_costs, generate_sharding_spec
from .misc import exception_handler
from .sharding import (enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding, generate_sharding_size,
                       switch_partition_dim, update_partition_dim)

__all__ = [
    'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
    'generate_resharding_costs', 'generate_sharding_spec', 'exception_handler', 'switch_partition_dim',
    'update_partition_dim', 'enumerate_all_possible_1d_sharding', 'enumerate_all_possible_2d_sharding',
    'generate_sharding_size'
]
@ -1,14 +1,17 @@
|
|||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
import torch
|
||||
from torch.fx.node import Node
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from typing import Union, Dict, List, Optional
|
||||
import operator
|
||||
import warnings
|
||||
from functools import reduce
|
||||
import functools
|
||||
import operator
|
||||
from .constants import INFINITY_COST
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from torch.fx.node import Node
|
||||
|
||||
from ..constants import INFINITY_COST
|
||||
|
||||
__all__ = ['generate_sharding_spec', 'generate_resharding_costs']
|
||||
|
||||
|
||||
def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
|
||||
|
@ -85,55 +88,3 @@ def generate_resharding_costs(nodes: List[Node],
|
|||
resharding_cost = INFINITY_COST
|
||||
resharding_costs[input_node].append(resharding_cost)
|
||||
return resharding_costs
|
||||
|
||||
|
||||
def exception_handler(func):
|
||||
"""
|
||||
A function wrapper which executes the function with a specified seed.
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
rst = func(*args, **kwargs)
|
||||
return rst
|
||||
except AssertionError as e:
|
||||
warnings.warn(f'{e}')
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size):
|
||||
dim_partition_list = []
|
||||
# enumerate all the 2D sharding cases
|
||||
for i in range(dim_size):
|
||||
for j in range(i + 1, dim_size):
|
||||
dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]}
|
||||
dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]}
|
||||
dim_partition_list.append(dim_partition_dict_0)
|
||||
dim_partition_list.append(dim_partition_dict_1)
|
||||
for i in range(dim_size):
|
||||
dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]}
|
||||
dim_partition_list.append(dim_partition_dict_flatten)
|
||||
|
||||
return dim_partition_list
|
||||
|
||||
|
||||
def enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size):
|
||||
dim_partition_list = []
|
||||
# enumerate all the 1D sharding cases
|
||||
for i in range(dim_size):
|
||||
dim_partition_dict_0 = {i: [mesh_dim_0]}
|
||||
dim_partition_list.append(dim_partition_dict_0)
|
||||
|
||||
return dim_partition_list
|
||||
|
||||
|
||||
def generate_sharding_size(dim_partition_dict, device_mesh):
|
||||
total_sharding_size = 1
|
||||
for mesh_dim_list in dim_partition_dict.values():
|
||||
mesh_dim_sharding_size = [device_mesh.shape[mesh_dim] for mesh_dim in mesh_dim_list]
|
||||
sharding_size = reduce(operator.mul, mesh_dim_sharding_size)
|
||||
total_sharding_size *= sharding_size
|
||||
|
||||
return total_sharding_size
|
|
@@ -0,0 +1,26 @@
import functools
import warnings

__all__ = ['exception_handler']


def exception_handler(func):
    """
    A function wrapper to handle the AssertionError in the function.

    Usage:
        # mute the assertion error in the function
        @exception_handler
        def do_something():
            ...
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            rst = func(*args, **kwargs)
            return rst
        except AssertionError as e:
            warnings.warn(f'{e}')

    return wrapper
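The docstring above already hints at the intended use; a concrete, self-contained example of the decorator's effect follows (the strategy-registration scenario is an illustrative assumption):

# Assumes exception_handler as defined in misc.py above.
@exception_handler
def register_strategy(shard_dim: int):
    # A generator might assert divisibility before emitting a sharding strategy; with the
    # decorator, a failed assertion is downgraded to a warning and the call returns None.
    assert shard_dim % 2 == 0, f'dim {shard_dim} is not divisible by the mesh dimension'
    return f'sharded along dim {shard_dim}'


print(register_strategy(4))    # 'sharded along dim 4'
print(register_strategy(3))    # emits a UserWarning and prints None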
@ -1,7 +1,16 @@
|
|||
import torch
|
||||
from typing import Dict
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
import operator
|
||||
from copy import deepcopy
|
||||
from functools import reduce
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
__all__ = [
|
||||
'switch_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
|
||||
'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
|
||||
]
|
||||
|
||||
|
||||
def switch_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec:
|
||||
|
@ -66,3 +75,39 @@ def update_partition_dim(sharding_spec: ShardingSpec,
|
|||
entire_shape=physical_shape,
|
||||
dim_partition_dict=new_dim_partition_dict)
|
||||
return current_sharding_spec
|
||||
|
||||
|
||||
def enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size):
|
||||
dim_partition_list = []
|
||||
# enumerate all the 2D sharding cases
|
||||
for i in range(dim_size):
|
||||
for j in range(i + 1, dim_size):
|
||||
dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]}
|
||||
dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]}
|
||||
dim_partition_list.append(dim_partition_dict_0)
|
||||
dim_partition_list.append(dim_partition_dict_1)
|
||||
for i in range(dim_size):
|
||||
dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]}
|
||||
dim_partition_list.append(dim_partition_dict_flatten)
|
||||
|
||||
return dim_partition_list
|
||||
|
||||
|
||||
def enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size):
|
||||
dim_partition_list = []
|
||||
# enumerate all the 1D sharding cases
|
||||
for i in range(dim_size):
|
||||
dim_partition_dict_0 = {i: [mesh_dim_0]}
|
||||
dim_partition_list.append(dim_partition_dict_0)
|
||||
|
||||
return dim_partition_list
|
||||
|
||||
|
||||
def generate_sharding_size(dim_partition_dict, device_mesh):
|
||||
total_sharding_size = 1
|
||||
for mesh_dim_list in dim_partition_dict.values():
|
||||
mesh_dim_sharding_size = [device_mesh.shape[mesh_dim] for mesh_dim in mesh_dim_list]
|
||||
sharding_size = reduce(operator.mul, mesh_dim_sharding_size)
|
||||
total_sharding_size *= sharding_size
|
||||
|
||||
return total_sharding_size
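A hand-traced example of the two enumeration helpers above, for a two-dimensional tensor on a 2x2 device mesh (illustrative; assumes the functions above are in scope):

# enumerate_all_possible_2d_sharding(mesh_dim_0=0, mesh_dim_1=1, dim_size=2) returns:
#   {0: [0], 1: [1]}, {0: [1], 1: [0]},   # the two tensor dims split across different mesh axes
#   {0: [0, 1]}, {1: [0, 1]}              # both mesh axes flattened onto a single tensor dim
print(enumerate_all_possible_2d_sharding(0, 1, dim_size=2))

# enumerate_all_possible_1d_sharding(0, dim_size=2) returns [{0: [0]}, {1: [0]}].
print(enumerate_all_possible_1d_sharding(0, dim_size=2))

# With a 2x2 mesh (device_mesh.shape == (2, 2)), the flattened partition {0: [0, 1]}
# shards dim 0 across 2 * 2 = 4 devices, so generate_sharding_size(...) == 4.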
@ -1,7 +1,9 @@
|
|||
import torch
|
||||
from colossalai.auto_parallel.solver.node_handler.broadcast import is_broadcastable, get_broadcast_shape, recover_sharding_spec_for_broadcast_shape
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.utils import (get_broadcast_shape, is_broadcastable,
|
||||
recover_sharding_spec_for_broadcast_shape)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
|
||||
def test_is_broadcastable():
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
import torch.nn as nn
|
||||
import torch
|
||||
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
|
||||
from colossalai.fx import ColoTracer, ColoGraphModule
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.solver import GraphAnalyser
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
|
||||
|
||||
class LinearModel(nn.Module):
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
from colossalai.fx.tracer.meta_patch.patched_module import linear
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from colossalai.fx import ColoTracer, ColoGraphModule
|
||||
from colossalai.auto_parallel.solver.node_handler.batch_norm_handler import BatchNormModuleHandler
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.batch_norm_handler import \
|
||||
BatchNormModuleHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.fx.tracer.meta_patch.patched_module import linear
|
||||
|
||||
|
||||
def test_bn_module_handler():
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from colossalai.fx import ColoTracer, ColoGraphModule
|
||||
from colossalai.auto_parallel.solver.node_handler.dot_handler import BMMFunctionHandler
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.dot_handler import \
|
||||
BMMFunctionHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
|
||||
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
from colossalai.fx.tracer.meta_patch.patched_module import linear
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from colossalai.fx import ColoTracer, ColoGraphModule
|
||||
from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvModuleHandler, ConvFunctionHandler
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import (ConvFunctionHandler, ConvModuleHandler)
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.fx.tracer.meta_patch.patched_module import linear
|
||||
|
||||
|
||||
def test_conv_module_handler():
|
||||
|
|
|
@ -1,11 +1,14 @@
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.getitem_handler import GetItemHandler
from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector

from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \
    ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import \
    GetItemHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear


class GetItemModel(nn.Module):
@ -1,10 +1,12 @@
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.layer_norm_handler import LayerNormModuleHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector

from colossalai.auto_parallel.tensor_shard.node_handler.layer_norm_handler import \
    LayerNormModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear


def test_ln_module_handler():
@ -1,10 +1,12 @@
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.dot_handler import LinearModuleHandler, LinearFunctionHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector, ShardingStrategy

from colossalai.auto_parallel.tensor_shard.node_handler.dot_handler import (LinearFunctionHandler, LinearModuleHandler)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, ShardingStrategy,
                                                                     StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.tensor.sharding_spec import ShardingSpec
@ -1,11 +1,13 @@
from colossalai.fx.tracer.meta_patch.patched_module import linear
import pytest
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.normal_pooling_handler import NormPoolingHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector

from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import \
    NormPoolingHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh
import pytest
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.testing.pytest_wrapper import run_on_environment_flag
@ -1,9 +1,11 @@
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.output_handler import OuputHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector

from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import \
    OuputHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer


class OutputModel(nn.Module):
@ -1,9 +1,11 @@
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.placeholder_handler import PlacehodlerHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector

from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import \
    PlacehodlerHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer


class PlaceholderModel(nn.Module):
@ -1,10 +1,13 @@
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.solver.node_handler.reshape_handler import ReshapeHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector

from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \
    ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import \
    ReshapeHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer


class ReshapeModel(nn.Module):
@ -1,11 +1,14 @@
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.unary_elementwise_handler import UnaryElementwiseHandler
from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector

from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \
    ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import \
    UnaryElementwiseHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear


class ReLuModel(nn.Module):
@ -1,10 +1,12 @@
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.where_handler import WhereHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector

from colossalai.auto_parallel.tensor_shard.node_handler.where_handler import \
    WhereHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear


class ConvModel(nn.Module):
@ -1,24 +1,22 @@
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from colossalai.initialize import launch
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.logging import disable_existing_loggers
from colossalai.auto_parallel.solver.cost_graph import CostGraph
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from torch.fx import GraphModule

from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.solver import (CostGraph, GraphAnalyser, Solver, SolverOptions,
                                                          StrategiesConstructor)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import shape_consistency_pass, solution_annotatation_pass
from colossalai.auto_parallel.solver.solver import Solver_V2
from colossalai.auto_parallel.solver.options import SolverOptions
from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import (shape_consistency_pass,
                                                                                solution_annotatation_pass)
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port


class ConvModel(nn.Module):
@ -61,7 +59,7 @@ def check_apply(rank, world_size, port):
    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
    cost_graph.simplify_graph()
    graph_analyser = GraphAnalyser(gm)
    solver = Solver_V2(gm.graph, strategies_constructor, cost_graph, graph_analyser)
    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
    ret = solver.call_solver_serialized_args()
    solution = list(ret[0])
    device_mesh.process_groups_dict = device_mesh.create_process_groups_for_logical_mesh()
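Putting the renamed pieces together, the solver pipeline under the consolidated imports looks roughly like the sketch below. The CostGraph/GraphAnalyser/Solver calls mirror check_apply() in the hunk above; the toy model, meta_args shapes, mesh layout, SolverOptions(fast=True) and the StrategiesConstructor signature are assumptions filled in for illustration, not shown in this hunk.

import torch
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.solver import (CostGraph, GraphAnalyser, Solver, SolverOptions,
                                                          StrategiesConstructor)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer

# hypothetical toy model traced on meta tensors
model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1)).to('meta')
tracer = ColoTracer()
graph = tracer.trace(model, meta_args={'input': torch.rand(4, 4, 64, 64).to('meta')})
gm = ColoGraphModule(model, graph, model.__class__.__name__)

# 4 physical devices arranged as a 2x2 logical mesh
physical_mesh_id = torch.arange(0, 4)
device_mesh = DeviceMesh(physical_mesh_id, (2, 2))

# SolverOptions(fast=True) and the StrategiesConstructor arguments are assumed
# from the surrounding test code of this release
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()

# the remaining calls mirror check_apply() in the hunk above
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
graph_analyser = GraphAnalyser(gm)
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
ret = solver.call_solver_serialized_args()
solution = list(ret[0])    # one chosen strategy index per node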
@ -1,20 +1,13 @@
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from torchvision.models import resnet50

from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
from colossalai.auto_parallel.tensor_shard.solver import (CostGraph, GraphAnalyser, Solver, SolverOptions,
                                                          StrategiesConstructor)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.solver.cost_graph import CostGraph
from copy import deepcopy
from colossalai.auto_parallel.solver.solver import Solver
from torchvision.models import resnet34, resnet50
from colossalai.auto_parallel.solver.constants import *
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.solver.options import SolverOptions
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.testing.pytest_wrapper import run_on_environment_flag