[autoparallel] refactored the autoparallel module for organization (#1706)

* [autoparallel] refactored the autoparallel module for organization

* polish code
pull/1707/head
Frank Lee 2022-10-14 13:27:00 +08:00 committed by GitHub
parent 91cd34e6e0
commit 6c331a5a09
57 changed files with 408 additions and 799 deletions
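
Taken together, the hunks below amount to a package move plus import regrouping: code that previously lived under colossalai.auto_parallel.solver now lives under colossalai.auto_parallel.tensor_shard, with the node handlers and strategy generators gathered into subpackages, the ILP machinery in tensor_shard.solver, and shared helpers in tensor_shard.utils. A before/after sketch, using only import paths that appear in this diff:

# before this commit
from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.solver.options import SolverOptions
from colossalai.auto_parallel.solver.solver import Solver_V2

# after this commit (Solver_V2 is also renamed to Solver)
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.solver import Solver, SolverOptions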


@ -1,12 +0,0 @@
from .sharding_strategy import ShardingStrategy, StrategiesVector
from .graph_analysis import GraphAnalyser
from .solver import Solver
from .cost_graph import CostGraph
from .strategies_constructor import StrategiesConstructor
from .constants import *
from .options import SolverOptions
__all__ = [
'StrategiesVector', 'ShardingStrategy', 'GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph',
'SolverOptions'
]


@ -1,16 +1,17 @@
from .batch_norm_handler import BatchNormModuleHandler
from .conv_handler import ConvFunctionHandler, ConvModuleHandler
from .dot_handler import LinearFunctionHandler, LinearModuleHandler
from .layer_norm_handler import LayerNormModuleHandler
from .batch_norm_handler import BatchNormModuleHandler
from .conv_handler import ConvModuleHandler, ConvFunctionHandler
from .where_handler import WhereHandler
from .unary_elementwise_handler import UnaryElementwiseHandler
from .reshape_handler import ReshapeHandler
from .placeholder_handler import PlacehodlerHandler
from .output_handler import OuputHandler
from .normal_pooling_handler import NormPoolingHandler
from .output_handler import OuputHandler
from .placeholder_handler import PlacehodlerHandler
from .registry import operator_registry
from .reshape_handler import ReshapeHandler
from .unary_elementwise_handler import UnaryElementwiseHandler
from .where_handler import WhereHandler
__all__ = [
'LinearFunctionHandler', 'LinearModuleHandler', 'LayerNormModuleHandler', 'BatchNormModuleHandler',
'ConvModuleHandler', 'ConvFunctionHandler', 'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler',
'OuputHandler', 'WhereHandler', 'NormPoolingHandler'
'OuputHandler', 'WhereHandler', 'NormPoolingHandler', 'operator_registry'
]
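
The handler package now also re-exports operator_registry, the registry that binds a handler class to the torch operator it covers (the conv handler hunk below shows it used as @operator_registry.register(torch.nn.Conv1d)). A minimal sketch of that pattern with a made-up module and handler, assuming register accepts any module class as its key:

import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.node_handler import operator_registry
from colossalai.auto_parallel.tensor_shard.node_handler.node_handler import ModuleHandler


class MyCustomConv(nn.Module):
    """Hypothetical module, present only to illustrate the registration pattern."""
    ...


@operator_registry.register(MyCustomConv)
class MyCustomConvHandler(ModuleHandler):
    """Hypothetical handler; real handlers implement the strategy-generation hooks."""
    ...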


@ -1,10 +1,11 @@
from typing import Dict, List
import torch
import torch.nn.functional as F
from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData
from ..strategy import BatchNormStrategyGenerator, StrategyGenerator
from typing import List, Dict
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import ModuleHandler
from .registry import operator_registry
from .strategy import BatchNormStrategyGenerator, StrategyGenerator
__all__ = ['BatchNormModuleHandler']


@ -1,12 +1,14 @@
from typing import Dict, List
import torch
import torch.nn.functional as F
from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData
from ..strategy import ConvStrategyGenerator, StrategyGenerator
from typing import List, Dict
from .registry import operator_registry
__all__ = ['LinearModuleHandler', 'LinearFunctionHandler']
from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy)
from .node_handler import ModuleHandler, NodeHandler
from .registry import operator_registry
from .strategy import ConvStrategyGenerator, StrategyGenerator
__all__ = ['ConvModuleHandler', 'ConvFunctionHandler']
@operator_registry.register(torch.nn.Conv1d)


@ -1,13 +1,16 @@
from copy import deepcopy
from typing import Dict, List, Union
import torch
import torch.nn.functional as F
from colossalai.auto_parallel.tensor_shard.utils import (switch_partition_dim, update_partition_dim)
from colossalai.tensor.sharding_spec import ShardingException
from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy)
from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData
from ..strategy import LinearProjectionStrategyGenerator, StrategyGenerator, BatchedMatMulStrategyGenerator
from typing import List, Dict, Union
from .registry import operator_registry
from copy import deepcopy
from .utils import switch_partition_dim, update_partition_dim
from .strategy import (BatchedMatMulStrategyGenerator, LinearProjectionStrategyGenerator, StrategyGenerator)
__all__ = ['LinearModuleHandler', 'LinearFunctionHandler', 'BMMFunctionHandler']


@ -1,10 +1,12 @@
import torch
from .node_handler import NodeHandler
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector
from ..strategy import TensorStrategyGenerator, TensorTupleStrategyGenerator, StrategyGenerator
from typing import List, Dict
from .registry import operator_registry
import operator
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import (StrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator)
__all__ = ['GetItemHandler']


@ -1,9 +1,11 @@
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import ModuleHandler
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData
from ..strategy import LayerNormGenerator, StrategyGenerator
from typing import List, Dict
from .registry import operator_registry
from .strategy import LayerNormGenerator, StrategyGenerator
__all__ = ['LayerNormModuleHandler']


@ -1,11 +1,14 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Union
from torch.fx.node import Node
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, ShardingStrategy, StrategiesVector,
TrainCycleItem)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from typing import Dict, List, Union
from ..sharding_strategy import ShardingStrategy, StrategiesVector, OperationData, TrainCycleItem
from ..strategy import StrategyGenerator
from .._utils import generate_resharding_costs
from .strategy import StrategyGenerator
class NodeHandler(ABC):


@ -1,10 +1,11 @@
from typing import Dict, List
import torch
import torch.nn.functional as F
from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData
from ..strategy import NormalPoolStrategyGenerator, StrategyGenerator
from typing import List, Dict
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import ModuleHandler
from .registry import operator_registry
from .strategy import NormalPoolStrategyGenerator, StrategyGenerator
__all__ = ['NormPoolingHandler']


@ -1,10 +1,10 @@
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector
from colossalai.auto_parallel.solver.strategy import StrategyGenerator
from colossalai.auto_parallel.solver.strategy.output_generator import OutputGenerator
from typing import List, Dict
from .registry import operator_registry
from .strategy import OutputGenerator, StrategyGenerator
__all__ = ['OuputHandler']


@ -1,10 +1,8 @@
import torch
from typing import Dict, List
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData
from colossalai.auto_parallel.solver.strategy import StrategyGenerator
from colossalai.auto_parallel.solver.strategy.placeholder_generator import PlaceholderGenerator
from typing import List, Dict
from .registry import operator_registry
from .strategy import PlaceholderGenerator, StrategyGenerator
__all__ = ['PlacehodlerHandler']


@ -1,10 +1,11 @@
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector
from ..strategy import ReshapeGenerator, StrategyGenerator
from typing import List, Dict
from .registry import operator_registry
import operator
from .strategy import ReshapeGenerator, StrategyGenerator
__all__ = ['ReshapeHandler']


@ -1,15 +1,16 @@
from .strategy_generator import StrategyGenerator
from .matmul_strategy_generator import DotProductStrategyGenerator, MatVecStrategyGenerator, LinearProjectionStrategyGenerator, BatchedMatMulStrategyGenerator
from .conv_strategy_generator import ConvStrategyGenerator
from .batch_norm_generator import BatchNormStrategyGenerator
from .unary_elementwise_generator import UnaryElementwiseGenerator
from .getitem_generator import GetItemStrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator
from .conv_strategy_generator import ConvStrategyGenerator
from .getitem_generator import (GetItemStrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator)
from .layer_norm_generator import LayerNormGenerator
from .where_generator import WhereGenerator
from .reshape_generator import ReshapeGenerator
from .matmul_strategy_generator import (BatchedMatMulStrategyGenerator, DotProductStrategyGenerator,
LinearProjectionStrategyGenerator, MatVecStrategyGenerator)
from .normal_pooling_generator import NormalPoolStrategyGenerator
from .placeholder_generator import PlaceholderGenerator
from .output_generator import OutputGenerator
from .placeholder_generator import PlaceholderGenerator
from .reshape_generator import ReshapeGenerator
from .strategy_generator import StrategyGenerator
from .unary_elementwise_generator import UnaryElementwiseGenerator
from .where_generator import WhereGenerator
__all__ = [
'StrategyGenerator', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator', 'LinearProjectionStrategyGenerator',


@ -1,11 +1,11 @@
import copy
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
from typing import List
from .._utils import exception_handler
import copy
__all__ = ['BatchNormStrategyGenerator']


@ -1,12 +1,14 @@
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
from typing import List
from .._utils import exception_handler
import warnings
import copy
import operator
import warnings
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import exception_handler
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
class ConvStrategyGenerator(StrategyGenerator):


@ -1,12 +1,10 @@
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import FollowingStrategyGenerator
from typing import List
from .._utils import exception_handler
import copy
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import FollowingStrategyGenerator
__all__ = ['GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator']


@ -1,11 +1,13 @@
import copy
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
from typing import List
from .._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
import copy
__all__ = ['LayerNormGenerator']


@ -1,11 +1,12 @@
from audioop import bias
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
class MatMulStrategyGenerator(StrategyGenerator):
"""


@ -1,11 +1,13 @@
import copy
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
from typing import List
from .._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
import copy
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding)
from .strategy_generator import StrategyGenerator
class NormalPoolStrategyGenerator(StrategyGenerator):


@ -1,11 +1,6 @@
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from .strategy_generator import OutputStrategyGenerator
from typing import List
from .._utils import exception_handler
import copy
__all__ = ['OutputGenerator']
@ -46,7 +41,7 @@ class OutputGenerator(OutputStrategyGenerator):
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = f'Replica Output'
name = 'Replica Output'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,


@ -1,11 +1,6 @@
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from .strategy_generator import StrategyGenerator
from typing import List
from .._utils import exception_handler
import copy
__all__ = ['PlaceholderGenerator']
@ -47,7 +42,7 @@ class PlaceholderGenerator(StrategyGenerator):
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = f'Replica Placeholder'
name = 'Replica Placeholder'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,


@ -1,11 +1,10 @@
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import FollowingStrategyGenerator
from typing import List
import copy
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import FollowingStrategyGenerator
__all__ = ['ReshapeGenerator']


@ -1,15 +1,16 @@
import operator
import torch
from colossalai.tensor.sharding_spec import ShardingSpec
from functools import reduce
from abc import ABC, abstractmethod
from functools import reduce
from typing import Any, Dict, List, Union
import torch
from torch.fx import Node
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, ShardingStrategy,
TrainCycleItem)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from typing import Dict, List, Union, Any
from ..sharding_strategy import OperationData, ShardingStrategy, TrainCycleItem, OperationDataType
from torch.fx import Node
import copy
class StrategyGenerator(ABC):


@ -1,12 +1,9 @@
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import FollowingStrategyGenerator
from typing import List
from .._utils import exception_handler
import copy
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from .strategy_generator import FollowingStrategyGenerator
__all__ = ['UnaryElementwiseGenerator']


@ -1,12 +1,11 @@
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator, FollowingStrategyGenerator
from typing import List
from .._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
import copy
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding)
from .strategy_generator import StrategyGenerator
__all__ = ['WhereGenerator']


@ -1,10 +1,11 @@
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector
from ..strategy import UnaryElementwiseGenerator, StrategyGenerator
from typing import List, Dict
from .registry import operator_registry
import operator
from .strategy import StrategyGenerator, UnaryElementwiseGenerator
__all__ = ['UnaryElementwiseHandler']


@ -1,12 +1,14 @@
import torch
from .node_handler import NodeHandler
from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector
from ..strategy import WhereGenerator, StrategyGenerator
from .broadcast import recover_sharding_spec_for_broadcast_shape
from typing import List, Dict
from .registry import operator_registry
import operator
import copy
import operator
from typing import Dict, List
import torch
from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy, StrategiesVector)
from ..utils import recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, WhereGenerator
__all__ = ['WhereHandler']


@ -1,17 +1,14 @@
from copy import deepcopy
from dataclasses import dataclass
from abc import ABC, abstractmethod
from enum import Enum
import operator
import torch
from functools import reduce
from typing import Any, Dict, List, Tuple, Union
from colossalai.device.device_mesh import DeviceMesh
import torch
from colossalai.tensor.shape_consistency import CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
from typing import Dict, List, Union, Tuple, Any
from torch.fx.node import Node
from .constants import *
from .constants import (BCAST_FUNC_OP, ELEMENTWISE_FUNC_OP, ELEMENTWISE_MODULE_OP, RESHAPE_FUNC_OP)
__all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector']
@ -75,6 +72,11 @@ class TrainCycleItem:
@dataclass
class MemoryCost:
"""
MemoryCost is a dataclass which stores the memory usage in the program.
Args:
activation (int): the memory cost incurred by the activations in bytes.
parameter (int): the memory cost incurred by the module parameter in bytes.
"""
activation: int = 0
parameter: int = 0
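# A hypothetical illustration of the two fields: a float32 activation of shape
# (64, 128) plus a float32 weight of shape (128, 128) would be recorded as
#   MemoryCost(activation=64 * 128 * 4, parameter=128 * 128 * 4)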


@ -0,0 +1,7 @@
from .cost_graph import CostGraph
from .graph_analysis import GraphAnalyser
from .options import SolverOptions
from .solver import Solver
from .strategies_constructor import StrategiesConstructor
__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph', 'SolverOptions']
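
The new tensor_shard.solver package groups the pieces that the apply-pass test near the end of this diff wires together. A minimal sketch of that wiring, assuming gm is a ColoTracer-traced GraphModule and strategies_constructor has already been built and run (its construction is not shown in this diff):

from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver

cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
graph_analyser = GraphAnalyser(gm)
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
# call_solver_serialized_args returns the per-node strategy indices first
solution = list(solver.call_solver_serialized_args()[0])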


@ -1,7 +1,4 @@
from typing import List
import math
from torch.fx.node import Node
from colossalai.auto_parallel.solver.constants import INFINITY_COST
from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST
class CostGraph:


@ -1,9 +1,10 @@
from dataclasses import dataclass
from torch.fx.node import Node
from typing import List
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
from collections import OrderedDict as ODict
from typing import List, OrderedDict, Union, Any
from torch.fx.node import Node
from colossalai.fx.passes.utils import get_node_module
__all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser']


@ -1,18 +1,21 @@
import warnings
import time
import numpy as np
import multiprocessing
from torch.fx.node import Node
from torch.fx.graph import Graph
from . import GraphAnalyser
from colossalai.auto_parallel.solver.cost_graph import CostGraph
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
import time
import warnings
from typing import Dict
from .constants import INFINITY_COST
import numpy as np
from torch.fx.graph import Graph
from torch.fx.node import Node
from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST
from .cost_graph import CostGraph
from .graph_analysis import GraphAnalyser
from .strategies_constructor import StrategiesConstructor
try:
import pulp
from pulp import LpVariable, LpProblem, LpMinimize, lpSum, lpDot, LpStatus
from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum
except:
warnings.warn('please install the pulp package')
@ -21,454 +24,6 @@ __all___ = ['Solver']
class Solver:
def __init__(self,
graph: Graph,
strategies_constructor: StrategiesConstructor,
cost_graph: CostGraph,
graph_analyser: GraphAnalyser,
memory_budget: float = -1.0,
solution_numbers: int = 1,
memory_increasing_coefficient: float = 1.3):
'''
The Solver class integrates the information provided by the components and uses an ILP solver to find a possibly optimal strategy combination for the target computing graph.
Argument:
graph: The computing graph to be optimized.
strategies_constructor: It will provide all the possible strategies for each node in the computing graph.
cost_graph: A graph data structure to simplify the edge cost graph.
graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints.
memory_budget: Memory constraint for the solution.
solution_numbers: If solution_numbers is larger than one, the solver will produce a series of solutions based on different memory budgets.
memory_increasing_coefficient: If solution_numbers is larger than one, this coefficient is used to generate the new memory budgets.
'''
self.graph = graph
self.strategies_constructor = strategies_constructor
self.cost_graph = cost_graph
self.graph_analyser = graph_analyser
self.leaf_strategies = self.strategies_constructor.leaf_strategies
self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
self.strategy_map = self.strategies_constructor.strategy_map
self.memory_budget = memory_budget
self.solution_numbers = solution_numbers
if self.solution_numbers > 1:
self.memory_increasing_coefficient = memory_increasing_coefficient
else:
self.memory_increasing_coefficient = 1
self.liveness_list = self.graph_analyser.liveness_analysis()
self.node_index_dict = self._generate_node_index_dict()
# The last solution vector of auto sharding.
self.last_s_val = None
# The last objective value of the best ILP solution.
self.last_objective = None
def _recover_merged_node_strategy(self):
'''
During cost graph construction, some nodes, such as unary element-wise nodes or ReshapeOp, were merged into the previous node.
Therefore, the strategy indices of those nodes were copied from the previous node. This method recovers the strategy index of those
merged nodes.
'''
for node_index, node in enumerate(self.nodes):
if node.strategies_vector.check_merge():
# the merged node has only one input, and its strategies follow the input sharding strategy
input_strategies_vector = node.args[0].strategies_vector
input_best_strategy_index = self.last_s_val[node_index - 1]
input_sharding_spec = input_strategies_vector[input_best_strategy_index].output_sharding_spec
for strategy_index, strategy in enumerate(node.strategies_vector):
if strategy.input_shardings[0].sharding_sequence == input_sharding_spec.sharding_sequence:
self.last_s_val[node_index] = strategy_index
break
def _generate_node_index_dict(self) -> Dict[Node, int]:
node_index_dict = {}
for index, strategies_vector in enumerate(self.leaf_strategies):
node_index_dict[strategies_vector.node] = index
return node_index_dict
def _prepare_data_for_solver(self):
'''
Extract information from components for solver.
'''
node_nums = len(self.leaf_strategies)
memory_budget = self.memory_budget
# prepare strategies_len
strategies_len = []
for node in self.nodes:
strategies_len.append(self.cost_graph.node_lens[node])
strategies_len = np.array(strategies_len)
# prepare following_nodes
following_nodes = self.cost_graph.following_dict
index_following_nodes = {}
for src, target in following_nodes.items():
src_index = self.node_index_dict[src]
target_index = self.node_index_dict[target]
index_following_nodes[src_index] = target_index
following_nodes = index_following_nodes
for index in range(node_nums):
if index not in following_nodes:
following_nodes[index] = -1
# prepare edge_pairs and resharding costs
edge_pairs = []
resharding_costs = []
for pairs, edge_cost in self.cost_graph.edge_costs.items():
src_node = pairs[0]
dst_node = pairs[1]
src_node_index = self.node_index_dict[src_node]
dst_node_index = self.node_index_dict[dst_node]
edge_pairs.append(src_node_index)
edge_pairs.append(dst_node_index)
for i in range(strategies_len[src_node_index]):
for j in range(strategies_len[dst_node_index]):
resharding_costs.append(edge_cost[(i, j)])
edge_pairs = np.array(edge_pairs)
resharding_costs = np.array(resharding_costs)
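# At this point edge_pairs is a flat [src_0, dst_0, src_1, dst_1, ...] index array, and
# resharding_costs stores, edge by edge in the same order, the cost of every
# (src strategy, dst strategy) pair laid out row-major.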
# prepare liveness_set
liveness_set = self.liveness_list
# omit alias_set now
alias_set = None
alias_convert_costs = None
# prepare compute_costs, communication_costs and memory_costs
compute_costs = []
communication_costs = []
memory_costs = []
extra_node_costs = self.cost_graph.extra_node_costs
for strategies_vector in self.leaf_strategies:
node = strategies_vector.node
for index, strategy in enumerate(strategies_vector):
compute_costs.append(strategy.compute_cost)
# a node present in extra_node_costs carries some extra communication
# cost from node merging, so we need to add that extra communication
# cost to its communication cost
if node in extra_node_costs:
origin_communication_cost = strategy.communication_cost
extra_node_cost = extra_node_costs[node][index]
communication_cost = origin_communication_cost + extra_node_cost
communication_costs.append(communication_cost)
else:
communication_costs.append(strategy.communication_cost)
# temporarily we just consider the forward memory cost
memory_cost = strategy.memory_cost
if isinstance(memory_cost, tuple):
memory_costs.append(memory_cost[0])
else:
memory_costs.append(memory_cost)
compute_costs = np.array(compute_costs)
communication_costs = np.array(communication_costs)
memory_costs = np.array(memory_costs)
# omit initial value for nodes
s_init_np = None
return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np
def _call_solver_serialized_args(self,
node_nums,
memory_budget,
strategies_len,
following_nodes,
edge_pairs,
alias_set,
liveness_set,
compute_costs,
communication_costs,
memory_costs,
resharding_costs,
alias_convert_costs,
s_init_np=None):
"""
Call the solver with serialized arguments.
"""
tic = time.time()
for x in [strategies_len, edge_pairs, compute_costs, communication_costs, memory_costs, resharding_costs]:
assert isinstance(x, np.ndarray)
assert len(strategies_len) == node_nums, "strategies_len"
def get_non_zero_index(binary_vector):
"""
Get the index of non-zero item in a vector.
"""
ct = 0
ret = None
for i, elem in enumerate(binary_vector):
if pulp.value(elem):
ret = i
ct += 1
assert ct == 1
return ret
# 0. Unpack flatten numpy arrays
s_follow = following_nodes
E = edge_pairs.reshape((-1, 2)) # noqa
r = []
pt = 0
edge_set = set()
for (i, j) in E:
prod_length = strategies_len[i] * strategies_len[j]
if (i, j) in edge_set:
raise ValueError(f"Duplicated edges: {(i, j)}")
edge_set.add((i, j))
r.append(resharding_costs[pt:pt + prod_length])
pt += prod_length
assert pt == len(resharding_costs)
######################
# omit alias set now #
######################
# A = alias_set.reshape((-1, 2)) # noqa
# for (i, j) in A:
# prod_length = strategies_len[i] * strategies_len[j]
# v.append(alias_convert_costs[pt:pt + prod_length])
# pt += prod_length
# assert pt == len(alias_convert_costs)
# L = [] # noqa
# pt = node_nums
# for i in range(node_nums):
# length = liveness_set[i]
# L.append(liveness_set[pt:pt + length])
# pt += length
# assert pt == len(liveness_set)
v = []
pt = 0
c = []
d = []
m = []
pt = 0
for i in range(node_nums):
length = strategies_len[i]
c.append(compute_costs[pt:pt + length])
d.append(communication_costs[pt:pt + length])
m.append(memory_costs[pt:pt + length])
pt += length
assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}"
assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}"
assert pt == len(memory_costs), f"{pt} == {len(memory_costs)}"
# 1. Create variables
#############################
# create variables for node #
#############################
s = []
num_nodes = 0
reverse_follow_backpatch = []
for i in range(node_nums):
if s_follow[i] < 0:
if strategies_len[i] == 1:
s.append([1])
else:
num_nodes += 1
s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
else:
if s_follow[i] < len(s):
s.append(s[s_follow[i]])
else:
s.append(None)
reverse_follow_backpatch.append(i)
for i in reverse_follow_backpatch:
s[i] = s[s_follow[i]]
#############################
# create variables for edge #
#############################
e = []
num_edges = 0
for (idx, (i, j)) in enumerate(E):
if len(s[i]) == 1:
e.append(s[j])
elif len(s[j]) == 1:
e.append(s[i])
else:
num_edges += 1
e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
assert len(e[idx]) == len(r[idx])
for element in s:
assert len(element) > 0
# 2. Set initial value
######################################
# set a initial value for warm start #
######################################
if s_init_np is not None:
s_init = s_init_np.reshape((-1, 3))
for (idx, value, fix) in s_init:
for i in range(len(s[idx])):
s[idx][i].setInitialValue(i == value)
if fix:
s[idx][i].fixValue()
# 3. Objective
prob = LpProblem("myProblem", LpMinimize)
###################################################################
# computing the node cost(computing cost and communication cost) #
###################################################################
obj = 0
for i in range(node_nums):
assert len(s[i]) == len(c[i])
assert len(s[i]) == len(d[i])
obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i])
#############################################
# computing the edge cost(resharding cost) #
#############################################
for i in range(len(E)):
assert len(e[i]) == len(r[i])
obj += lpDot(e[i], r[i])
prob += obj
# 4. Constraints
# (a). specified by `cat="Binary"`
# (b)
#################################################
# make sure each node only choose one strategy #
#################################################
for i in range(node_nums):
if s_follow[i] < 0:
prob += lpSum(s[i]) == 1
# (c)
#################################################
# compute memory consumption with liveness set #
#################################################
if memory_budget > 0:
for liveness_stage in liveness_set:
mem = 0
for live_variable in liveness_stage.unique_live_vars:
node_index = self.node_index_dict[live_variable.node]
mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
prob += mem <= memory_budget
# (d). specified by `cat="Binary"`
for (idx, (i, j)) in enumerate(E):
if strategies_len[i] == 1 or strategies_len[j] == 1:
continue
# (e)
prob += lpSum(e[idx]) == 1
# (f)
for row in range(len(s[i])):
C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row]
# (g)
for col in range(len(s[j])):
R = len(s[i]) # noqa
C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col]
# (h)
######################
# omit alias set now #
######################
# alias_set = set()
# for (idx, (i, j)) in enumerate(A):
# R = len(s[i]) # noqa
# C = len(s[j]) # noqa
# if (i, j) in alias_set:
# raise ValueError(f"Duplicated edges: {(i, j)}")
# alias_set.add((i, j))
# alias_set.add((j, i))
# for row in range(len(s[i])):
# for col in range(len(s[j])):
# if v[idx][row * C + col] > 0.5:
# prob += s[i][row] + s[j][col] <= 1
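# In summary, the ILP assembled above is: with binary strategy vectors s[i] per node and
# binary edge vectors e[(i, j)] per resharding edge,
#   minimize    sum_i lpDot(s[i], c[i]) + lpDot(s[i], d[i])  +  sum_(i,j) lpDot(e[(i,j)], r[(i,j)])
#   subject to  lpSum(s[i]) == 1 for every node that does not follow another node,
#               lpSum(e[(i,j)]) == 1 with e consistent with the chosen entries of s[i] and s[j],
#               and, when a memory budget is given, the live memory at each liveness stage
#               staying within that budget.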
verbose = True
msg = verbose
time_limit = 600
assert "COIN_CMD" in pulp.listSolvers(
onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'")
solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count())
# solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit)
prob.solve(solver)
status = prob.status
objective = pulp.value(prob.objective)
objective = float(objective) if objective is not None else -1.0
if verbose:
print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t"
f"Time: {time.time() - tic}")
print(f"#nodes: {num_nodes}, #edges: {num_edges}")
if prob.status in [pulp.LpStatusInfeasible]:
raise RuntimeError("Cannot run the function under the given memory budget. "
"Please increase the memory budget.")
# Get and check results
s_val = np.full((node_nums,), -1, dtype=np.int32)
for i in range(node_nums):
s_val[i] = get_non_zero_index(s[i])
e_val = np.full((len(E),), -1, dtype=np.int32)
for (idx, (i, j)) in enumerate(E):
e_val[idx] = get_non_zero_index(e[idx])
i_spec_index = e_val[idx] // len(s[j])
j_spec_index = e_val[idx] % len(s[j])
assert i_spec_index == s_val[i], f"e_val[{i}][{j}]"
assert j_spec_index == s_val[j], f"e_val[{i}][{j}]"
if verbose and r[idx][e_val[idx]] > 0:
print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}")
self.last_s_val = list(s_val)
self._recover_merged_node_strategy()
self.last_objective = objective
if objective > INFINITY_COST:
warnings.warn("Detect unexpected behaviors in the auto-sharding pass.")
return self.last_s_val, e_val, self.last_objective, status
def call_solver_serialized_args(self):
"""
Call the solver with serialized arguments and handle Python errors. Additionally,
it can produce a series of solutions with different memory budgets.
"""
if self.solution_numbers == 1:
args = self._prepare_data_for_solver()
ret = self._call_solver_serialized_args(*args)
return ret
origin_memory_budget = self.memory_budget
memory_budget_list = [
origin_memory_budget * self.memory_increasing_coefficient**i for i in range(self.solution_numbers)
]
ret_list = []
for memory_budget in memory_budget_list:
self.memory_budget = memory_budget
args = self._prepare_data_for_solver()
ret = self._call_solver_serialized_args(*args)
ret_list.append(ret)
return ret_list
class Solver_V2:
def __init__(self,
graph: Graph,
strategies_constructor: StrategiesConstructor,
@ -480,7 +35,6 @@ class Solver_V2:
memory_increasing_coefficient: float = 1.3):
'''
The Solver class integrates the information provided by the components and uses an ILP solver to find a possibly optimal strategy combination for the target computing graph.
Argument:
graph: The computing graph to be optimized.
strategies_constructor: It will provide all the possible strategies for each node in the computing graph.


@ -1,22 +1,19 @@
import math
import operator
from copy import deepcopy
from typing import Dict, List
import torch
from torch.fx import Graph, Node
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.auto_parallel.tensor_shard.node_handler import (OuputHandler, PlacehodlerHandler, operator_registry)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (ShardingStrategy, StrategiesVector)
from colossalai.auto_parallel.tensor_shard.utils import (generate_resharding_costs, generate_sharding_spec)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.auto_parallel.solver.node_handler.registry import operator_registry
from colossalai.auto_parallel.solver.node_handler.placeholder_handler import PlacehodlerHandler
from colossalai.auto_parallel.solver.node_handler.output_handler import OuputHandler
from colossalai.tensor.sharding_spec import ShardingSpec
from .options import SolverOptions
from . import ShardingStrategy, StrategiesVector
from .node_handler import *
from .constants import *
from copy import deepcopy
import math
import torch
import operator
from typing import Dict, List
from ._utils import generate_sharding_spec, generate_resharding_costs
import builtins
__all__ = ['StrategiesConstructor']


@ -0,0 +1,12 @@
from .broadcast import (BroadcastType, get_broadcast_shape, is_broadcastable, recover_sharding_spec_for_broadcast_shape)
from .factory import generate_resharding_costs, generate_sharding_spec
from .misc import exception_handler
from .sharding import (enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding, generate_sharding_size,
switch_partition_dim, update_partition_dim)
__all__ = [
'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
'generate_resharding_costs', 'generate_sharding_spec', 'exception_handler', 'switch_partition_dim',
'update_partition_dim', 'enumerate_all_possible_1d_sharding', 'enumerate_all_possible_2d_sharding',
'generate_sharding_size'
]


@ -1,14 +1,17 @@
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
import torch
from torch.fx.node import Node
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from typing import Union, Dict, List, Optional
import operator
import warnings
from functools import reduce
import functools
import operator
from .constants import INFINITY_COST
from typing import Dict, List, Optional, Union
import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from torch.fx.node import Node
from ..constants import INFINITY_COST
__all__ = ['generate_sharding_spec', 'generate_resharding_costs']
def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
@ -85,55 +88,3 @@ def generate_resharding_costs(nodes: List[Node],
resharding_cost = INFINITY_COST
resharding_costs[input_node].append(resharding_cost)
return resharding_costs
def exception_handler(func):
"""
A function wrapper which executes the function with a specified seed.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
rst = func(*args, **kwargs)
return rst
except AssertionError as e:
warnings.warn(f'{e}')
return wrapper
def enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size):
dim_partition_list = []
# enumerate all the 2D sharding cases
for i in range(dim_size):
for j in range(i + 1, dim_size):
dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]}
dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]}
dim_partition_list.append(dim_partition_dict_0)
dim_partition_list.append(dim_partition_dict_1)
for i in range(dim_size):
dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]}
dim_partition_list.append(dim_partition_dict_flatten)
return dim_partition_list
def enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size):
dim_partition_list = []
# enumerate all the 1D sharding cases
for i in range(dim_size):
dim_partition_dict_0 = {i: [mesh_dim_0]}
dim_partition_list.append(dim_partition_dict_0)
return dim_partition_list
def generate_sharding_size(dim_partition_dict, device_mesh):
total_sharding_size = 1
for mesh_dim_list in dim_partition_dict.values():
mesh_dim_sharding_size = [device_mesh.shape[mesh_dim] for mesh_dim in mesh_dim_list]
sharding_size = reduce(operator.mul, mesh_dim_sharding_size)
total_sharding_size *= sharding_size
return total_sharding_size


@ -0,0 +1,26 @@
import functools
import warnings
__all__ = ['exception_handler']
def exception_handler(func):
"""
A function wrapper to handle the AssertionError in the function.
Usage:
# mute the assertion error in the function
@exception_handler
def do_something():
...
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
rst = func(*args, **kwargs)
return rst
except AssertionError as e:
warnings.warn(f'{e}')
return wrapper
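
In other words, the decorator converts a failed assertion inside the wrapped function into a warning instead of an exception. A small self-contained illustration (the wrapped function is made up):

from colossalai.auto_parallel.tensor_shard.utils import exception_handler


@exception_handler
def check_positive(x):
    # a failing assert here is caught by the wrapper and re-emitted as a warning
    assert x > 0, f'{x} is not positive'
    return x


check_positive(-1)    # warns instead of raising, and returns None
check_positive(3)     # returns 3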


@ -1,7 +1,16 @@
import torch
from typing import Dict
from colossalai.tensor.sharding_spec import ShardingSpec
import operator
from copy import deepcopy
from functools import reduce
from typing import Dict
import torch
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = [
'switch_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
]
def switch_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec:
@ -66,3 +75,39 @@ def update_partition_dim(sharding_spec: ShardingSpec,
entire_shape=physical_shape,
dim_partition_dict=new_dim_partition_dict)
return current_sharding_spec
def enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size):
dim_partition_list = []
# enumerate all the 2D sharding cases
for i in range(dim_size):
for j in range(i + 1, dim_size):
dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]}
dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]}
dim_partition_list.append(dim_partition_dict_0)
dim_partition_list.append(dim_partition_dict_1)
for i in range(dim_size):
dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]}
dim_partition_list.append(dim_partition_dict_flatten)
return dim_partition_list
def enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size):
dim_partition_list = []
# enumerate all the 1D sharding cases
for i in range(dim_size):
dim_partition_dict_0 = {i: [mesh_dim_0]}
dim_partition_list.append(dim_partition_dict_0)
return dim_partition_list
def generate_sharding_size(dim_partition_dict, device_mesh):
total_sharding_size = 1
for mesh_dim_list in dim_partition_dict.values():
mesh_dim_sharding_size = [device_mesh.shape[mesh_dim] for mesh_dim in mesh_dim_list]
sharding_size = reduce(operator.mul, mesh_dim_sharding_size)
total_sharding_size *= sharding_size
return total_sharding_size
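
For a concrete sense of what the enumeration helpers above produce, take a 2-dimensional tensor (dim_size=2) on a device mesh with mesh dims 0 and 1; a short sketch using the exports of the new tensor_shard.utils package:

from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
                                                         enumerate_all_possible_2d_sharding)

# 1D: shard either tensor dim along mesh dim 0
print(enumerate_all_possible_1d_sharding(0, 2))
# -> [{0: [0]}, {1: [0]}]

# 2D: the two cross assignments plus flattening both mesh dims onto a single tensor dim
print(enumerate_all_possible_2d_sharding(0, 1, 2))
# -> [{0: [0], 1: [1]}, {0: [1], 1: [0]}, {0: [0, 1]}, {1: [0, 1]}]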


@ -1,7 +1,9 @@
import torch
from colossalai.auto_parallel.solver.node_handler.broadcast import is_broadcastable, get_broadcast_shape, recover_sharding_spec_for_broadcast_shape
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.auto_parallel.tensor_shard.utils import (get_broadcast_shape, is_broadcastable,
recover_sharding_spec_for_broadcast_shape)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
def test_is_broadcastable():


@ -1,7 +1,8 @@
import torch.nn as nn
import torch
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
from colossalai.fx import ColoTracer, ColoGraphModule
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.solver import GraphAnalyser
from colossalai.fx import ColoGraphModule, ColoTracer
class LinearModel(nn.Module):


@ -1,10 +1,12 @@
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.batch_norm_handler import BatchNormModuleHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.auto_parallel.tensor_shard.node_handler.batch_norm_handler import \
BatchNormModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
def test_bn_module_handler():


@ -1,10 +1,12 @@
import pytest
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.dot_handler import BMMFunctionHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.auto_parallel.tensor_shard.node_handler.dot_handler import \
BMMFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag


@ -1,10 +1,11 @@
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvModuleHandler, ConvFunctionHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import (ConvFunctionHandler, ConvModuleHandler)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
def test_conv_module_handler():


@ -1,11 +1,14 @@
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.getitem_handler import GetItemHandler
from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \
ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import \
GetItemHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
class GetItemModel(nn.Module):


@ -1,10 +1,12 @@
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.layer_norm_handler import LayerNormModuleHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.auto_parallel.tensor_shard.node_handler.layer_norm_handler import \
LayerNormModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
def test_ln_module_handler():


@ -1,10 +1,12 @@
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.dot_handler import LinearModuleHandler, LinearFunctionHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector, ShardingStrategy
from colossalai.auto_parallel.tensor_shard.node_handler.dot_handler import (LinearFunctionHandler, LinearModuleHandler)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, ShardingStrategy,
StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.tensor.sharding_spec import ShardingSpec


@ -1,11 +1,13 @@
from colossalai.fx.tracer.meta_patch.patched_module import linear
import pytest
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.normal_pooling_handler import NormPoolingHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import \
NormPoolingHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh
import pytest
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.testing.pytest_wrapper import run_on_environment_flag


@ -1,9 +1,11 @@
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.output_handler import OuputHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import \
OuputHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
class OutputModel(nn.Module):


@ -1,9 +1,11 @@
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.placeholder_handler import PlacehodlerHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import \
PlacehodlerHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
class PlaceholderModel(nn.Module):


@ -1,10 +1,13 @@
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.solver.node_handler.reshape_handler import ReshapeHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \
ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import \
ReshapeHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
class ReshapeModel(nn.Module):


@ -1,11 +1,14 @@
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.unary_elementwise_handler import UnaryElementwiseHandler
from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \
ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import \
UnaryElementwiseHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
class ReLuModel(nn.Module):


@ -1,10 +1,12 @@
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.where_handler import WhereHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.auto_parallel.tensor_shard.node_handler.where_handler import \
WhereHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
class ConvModel(nn.Module):


@ -1,24 +1,22 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from colossalai.initialize import launch
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.logging import disable_existing_loggers
from colossalai.auto_parallel.solver.cost_graph import CostGraph
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from torch.fx import GraphModule
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.solver import (CostGraph, GraphAnalyser, Solver, SolverOptions,
StrategiesConstructor)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import shape_consistency_pass, solution_annotatation_pass
from colossalai.auto_parallel.solver.solver import Solver_V2
from colossalai.auto_parallel.solver.options import SolverOptions
from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import (shape_consistency_pass,
solution_annotatation_pass)
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
class ConvModel(nn.Module):
@ -61,7 +59,7 @@ def check_apply(rank, world_size, port):
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
graph_analyser = GraphAnalyser(gm)
solver = Solver_V2(gm.graph, strategies_constructor, cost_graph, graph_analyser)
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
ret = solver.call_solver_serialized_args()
solution = list(ret[0])
device_mesh.process_groups_dict = device_mesh.create_process_groups_for_logical_mesh()


@ -1,20 +1,13 @@
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from torchvision.models import resnet50
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
from colossalai.auto_parallel.tensor_shard.solver import (CostGraph, GraphAnalyser, Solver, SolverOptions,
StrategiesConstructor)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.solver.cost_graph import CostGraph
from copy import deepcopy
from colossalai.auto_parallel.solver.solver import Solver
from torchvision.models import resnet34, resnet50
from colossalai.auto_parallel.solver.constants import *
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.solver.options import SolverOptions
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.testing.pytest_wrapper import run_on_environment_flag