mirror of https://github.com/hpcaitech/ColossalAI
parent 5b24987fa7
commit 21d6a48f4d
@@ -8,14 +8,9 @@ from torch.fx.graph import Graph

 from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
 from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
+from colossalai.auto_parallel.tensor_shard.options import DataloaderOption, ShardOption, SolverOptions, SolverPerference
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction
-from colossalai.auto_parallel.tensor_shard.solver import (
-    CostGraph,
-    GraphAnalyser,
-    Solver,
-    SolverOptions,
-    StrategiesConstructor,
-)
+from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
 from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.graph_module import ColoGraphModule
@@ -69,13 +64,43 @@ def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[f
     pass


-def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh):
+def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str,
+                               shard_option: str):
     '''
     This method is used to build the strategy_constructor for the given graph.
     After this method, each node in the graph will have a strategies_vector which
     is constructed by the related node handler.
     '''
-    solver_options = SolverOptions()
+    if solver_preference == 'standard':
+        solver_preference = SolverPerference.STANDARD
+    elif solver_preference == 'tp':
+        solver_preference = SolverPerference.TP
+    elif solver_preference == 'dp':
+        solver_preference = SolverPerference.DP
+    else:
+        raise ValueError(f'Invalid solver_preference: {solver_preference}')
+
+    if dataloader_option == 'replicated':
+        dataloader_option = DataloaderOption.REPLICATED
+    elif dataloader_option == 'distributed':
+        dataloader_option = DataloaderOption.DISTRIBUTED
+    else:
+        raise ValueError(f'Invalid dataloader_option: {dataloader_option}')
+
+    if shard_option == 'standard':
+        shard_option = ShardOption.STANDARD
+    elif shard_option == 'shard':
+        shard_option = ShardOption.SHARD
+    elif shard_option == 'shard_last_axis':
+        shard_option = ShardOption.SHARD_LAST_AXIS
+    elif shard_option == 'full_shard':
+        shard_option = ShardOption.FULL_SHARD
+    else:
+        raise ValueError(f'Invalid shard_option: {shard_option}')
+
+    solver_options = SolverOptions(solver_perference=solver_preference,
+                                   dataloader_option=dataloader_option,
+                                   shard_option=shard_option)
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()

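Below is a minimal usage sketch of the extended signature; `graph` and `device_mesh` are placeholders for a traced torch.fx graph and an existing DeviceMesh, not values taken from this change:

    # Hypothetical call site: the string options are mapped onto the
    # SolverPerference / DataloaderOption / ShardOption enums inside
    # build_strategy_constructor before the StrategiesConstructor is built.
    strategies_constructor = build_strategy_constructor(graph,
                                                        device_mesh,
                                                        solver_preference='tp',
                                                        dataloader_option='distributed',
                                                        shard_option='shard_last_axis')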
@@ -183,6 +208,9 @@ def initialize_model(model: nn.Module,
                      device_mesh: DeviceMesh,
                      memory_budget: float = -1.0,
                      overlap: bool = False,
+                     solver_preference: str = 'standard',
+                     dataloader_option: str = 'replicated',
+                     shard_option: str = 'standard',
                      save_solver_solution: bool = False,
                      load_solver_solution: bool = False,
                      solution_path: str = None,
@@ -198,6 +226,12 @@ def initialize_model(model: nn.Module,
                         the memory budget will be infinity.
         overlap(optional): the overlap is used to specify whether to overlap gradient communication and
                         backward computing.
+        solver_preference(optional): the solver_preference is used to specify which parallelism algorithm
+                        has higher priority. The valid solver_preference could be 'standard', 'tp', or 'dp'.
+        dataloader_option(optional): the dataloader_option is used to specify which kind of data_loader will
+                        be used. The valid dataloader_option could be 'replicated' or 'distributed'.
+        shard_option(optional): the shard_option is used to specify how many axes will be used to shard the
+                        model. The valid shard_option could be 'standard', 'shard', 'shard_last_axis', or 'full_shard'.
         save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
                         to the solution_path.
         load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
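A hedged sketch of passing these options to initialize_model; the toy module, meta_args, and DeviceMesh construction below are illustrative assumptions, and the call mirrors the pattern autoparallelize uses further down in this change:

    # Illustrative only: tiny module, meta-device inputs, and an assumed
    # DeviceMesh(physical_mesh_id, mesh_shape, ...) constructor.
    import torch
    import torch.nn as nn

    from colossalai.device.device_mesh import DeviceMesh

    model = nn.Linear(16, 16)
    meta_args = {'input': torch.rand(4, 16, device='meta')}
    device_mesh = DeviceMesh(torch.arange(4), mesh_shape=(2, 2), init_process_group=True)

    rst_to_unpack = initialize_model(model,
                                     meta_args,
                                     device_mesh,
                                     solver_preference='standard',
                                     dataloader_option='replicated',
                                     shard_option='standard')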
@@ -212,7 +246,12 @@ def initialize_model(model: nn.Module,
     graph = tracer.trace(root=model, meta_args=meta_args)
     gm = ColoGraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
-    strategies_constructor = build_strategy_constructor(graph, device_mesh)
+
+    strategies_constructor = build_strategy_constructor(graph,
+                                                        device_mesh,
+                                                        solver_preference=solver_preference,
+                                                        dataloader_option=dataloader_option,
+                                                        shard_option=shard_option)
     if load_solver_solution:
         solution = torch.load(solution_path)
     else:
@@ -240,6 +279,9 @@ def autoparallelize(model: nn.Module,
                     alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
                     logical_mesh_shape: Tuple[int] = None,
                     logical_mesh_id: torch.Tensor = None,
+                    solver_preference: str = 'standard',
+                    dataloader_option: str = 'replicated',
+                    shard_option: str = 'standard',
                     save_solver_solution: bool = False,
                     load_solver_solution: bool = False,
                     solver_solution_path: str = None,
@@ -262,6 +304,12 @@ def autoparallelize(model: nn.Module,
                         mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
                         generated by search_best_logical_mesh_shape function.
         logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
+        solver_preference(optional): the solver_preference is used to specify which parallelism algorithm
+                        has higher priority. The valid solver_preference could be 'standard', 'tp', or 'dp'.
+        dataloader_option(optional): the dataloader_option is used to specify which kind of data_loader will
+                        be used. The valid dataloader_option could be 'replicated' or 'distributed'.
+        shard_option(optional): the shard_option is used to specify how many axes will be used to shard the
+                        model. The valid shard_option could be 'standard', 'shard', 'shard_last_axis', or 'full_shard'.
         save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
                         to the solution_path.
         load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
@@ -280,6 +328,8 @@ def autoparallelize(model: nn.Module,
     rst_to_unpack = initialize_model(model,
                                      meta_args,
                                      device_mesh,
+                                     solver_preference=solver_preference,
+                                     dataloader_option=dataloader_option,
                                      save_solver_solution=save_solver_solution,
                                      load_solver_solution=load_solver_solution,
                                      solution_path=solver_solution_path,

@@ -11,7 +11,6 @@ from .layer_norm_handler import LayerNormModuleHandler
 from .linear_handler import LinearFunctionHandler, LinearModuleHandler
 from .matmul_handler import MatMulHandler
 from .normal_pooling_handler import NormPoolingHandler
-from .option import ShardOption
 from .output_handler import OutputHandler
 from .permute_handler import PermuteHandler
 from .placeholder_handler import PlaceholderHandler
@@ -31,6 +30,6 @@ __all__ = [
     'UnaryElementwiseHandler', 'DefaultReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
     'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
     'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
-    'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'ShardOption',
-    'TransposeHandler', 'SplitHandler'
+    'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'TransposeHandler',
+    'SplitHandler'
 ]

@@ -5,7 +5,7 @@ import torch
 from torch.fx.node import Node

 from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
-from colossalai.auto_parallel.tensor_shard.node_handler.option import ShardOption
+from colossalai.auto_parallel.tensor_shard.options import ShardOption, SolverPerference
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationData,
     OperationDataType,
@@ -32,19 +32,19 @@ class NodeHandler(ABC):
         strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
     '''

-    def __init__(
-        self,
-        node: Node,
-        device_mesh: DeviceMesh,
-        strategies_vector: StrategiesVector,
-        shard_option: ShardOption = ShardOption.STANDARD,
-    ) -> None:
+    def __init__(self,
+                 node: Node,
+                 device_mesh: DeviceMesh,
+                 strategies_vector: StrategiesVector,
+                 shard_option: ShardOption = ShardOption.STANDARD,
+                 solver_perference: SolverPerference = SolverPerference.STANDARD) -> None:
         self.node = node
         self.predecessor_node = list(node._input_nodes.keys())
         self.successor_node = list(node.users.keys())
         self.device_mesh = device_mesh
         self.strategies_vector = strategies_vector
         self.shard_option = shard_option
+        self.solver_perference = solver_perference

     def update_resharding_cost(self, strategy: ShardingStrategy) -> None:
         """
@@ -187,15 +187,24 @@ class NodeHandler(ABC):

         remove_strategy_list = []
         for strategy in self.strategies_vector:
-            shard_level = 0
+            shard_axis_list = []
+            last_axis = len(self.device_mesh.mesh_shape) - 1
             for op_data, sharding_spec in strategy.sharding_specs.items():
                 if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
-                    for dim, shard_axis in sharding_spec.dim_partition_dict.items():
-                        shard_level += len(shard_axis)
+                    for dim, shard_axes in sharding_spec.dim_partition_dict.items():
+                        for shard_axis in shard_axes:
+                            if shard_axis not in shard_axis_list:
+                                shard_axis_list.append(shard_axis)
+
+            shard_level = len(shard_axis_list)
+            using_last_axis = last_axis in shard_axis_list or -1 in shard_axis_list
             if self.shard_option == ShardOption.SHARD and shard_level == 0:
                 remove_strategy_list.append(strategy)
             if self.shard_option == ShardOption.FULL_SHARD and shard_level <= 1:
                 remove_strategy_list.append(strategy)
+            if self.shard_option == ShardOption.SHARD_LAST_AXIS:
+                if shard_level != 1 or using_last_axis == False:
+                    remove_strategy_list.append(strategy)

         for strategy in remove_strategy_list:
             self.strategies_vector.remove(strategy)

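The post-processing above now counts distinct device-mesh axes instead of summing sharded dims per tensor; a standalone sketch of that counting logic, using a made-up dim_partition_dict and a 2x2 mesh:

    # Sketch of the shard-level computation: collect the distinct mesh axes a
    # sharding spec refers to and check whether the last mesh axis is used.
    dim_partition_dict = {0: [0], 1: [1]}    # tensor dim 0 on mesh axis 0, dim 1 on mesh axis 1 (illustrative)
    mesh_shape = (2, 2)                      # illustrative 2D device mesh
    last_axis = len(mesh_shape) - 1

    shard_axis_list = []
    for dim, shard_axes in dim_partition_dict.items():
        for shard_axis in shard_axes:
            if shard_axis not in shard_axis_list:
                shard_axis_list.append(shard_axis)

    shard_level = len(shard_axis_list)    # number of distinct mesh axes used by this strategy
    using_last_axis = last_axis in shard_axis_list or -1 in shard_axis_list
    print(shard_level, using_last_axis)    # 2 True for this example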
@@ -1,17 +0,0 @@
-from enum import Enum
-
-__all__ = ['ShardOption']
-
-
-class ShardOption(Enum):
-    """
-    This enum class is to define the shard level required in node strategies.
-
-    Notes:
-        STANDARD: We do not add any extra shard requirements.
-        SHARD: We require the node to be shard using at least one device mesh axis.
-        FULL_SHARD: We require the node to be shard using all device mesh axes.
-    """
-    STANDARD = 0
-    SHARD = 1
-    FULL_SHARD = 2

@@ -0,0 +1,49 @@
+from dataclasses import dataclass
+from enum import Enum
+
+__all__ = ['SolverOptions', 'SolverPerference', 'DataloaderOption', 'ShardOption']
+
+
+class SolverPerference(Enum):
+    """
+    This enum class is to define the solver preference.
+    """
+    STANDARD = 0
+    DP = 1
+    TP = 2
+
+
+class ShardOption(Enum):
+    """
+    This enum class is to define the shard level required in node strategies.
+
+    Notes:
+        STANDARD: We do not add any extra shard requirements.
+        SHARD: We require the node to be shard using at least one device mesh axis.
+        SHARD_ONE_AXIS: We require the node to be shard using the last device mesh axis.
+        FULL_SHARD: We require the node to be shard using all device mesh axes.
+        TP_SHARD: We require the node to be shard using tensor parallel strategies on last device mesh axis.
+        TP_FULL_SHARD: We require the node to be shard using tensor parallel strategies on all device mesh axes.
+    """
+    STANDARD = 0
+    SHARD = 1
+    SHARD_LAST_AXIS = 2
+    FULL_SHARD = 3
+
+
+class DataloaderOption(Enum):
+    """
+    This enum class is to define the dataloader option.
+    """
+    REPLICATED = 0
+    DISTRIBUTED = 1
+
+
+@dataclass
+class SolverOptions:
+    """
+    SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search.
+    """
+    solver_perference: SolverPerference = SolverPerference.STANDARD
+    dataloader_option: DataloaderOption = DataloaderOption.REPLICATED
+    shard_option: ShardOption = ShardOption.STANDARD

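For reference, a short sketch of building the new SolverOptions directly from these enums (the chosen values are illustrative, not the defaults):

    from colossalai.auto_parallel.tensor_shard.options import (
        DataloaderOption,
        ShardOption,
        SolverOptions,
        SolverPerference,
    )

    # Prefer tensor parallelism, distribute the dataloader, and require every
    # strategy to shard along the last device-mesh axis.
    options = SolverOptions(solver_perference=SolverPerference.TP,
                            dataloader_option=DataloaderOption.DISTRIBUTED,
                            shard_option=ShardOption.SHARD_LAST_AXIS)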
@@ -1,7 +1,6 @@
 from .cost_graph import CostGraph
 from .graph_analysis import GraphAnalyser
-from .options import SolverOptions
 from .solver import Solver
 from .strategies_constructor import StrategiesConstructor

-__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph', 'SolverOptions']
+__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph']

@@ -1,30 +0,0 @@
-from dataclasses import dataclass
-from enum import Enum
-
-__all__ = ['SolverOptions']
-
-
-class SolverPerference(Enum):
-    """
-    This enum class is to define the solver preference.
-    """
-    STANDARD = 0
-    DP = 1
-    TP = 2
-
-
-class DataloaderOption(Enum):
-    """
-    This enum class is to define the dataloader option.
-    """
-    REPLICATED = 0
-    DISTRIBUTED = 1
-
-
-@dataclass
-class SolverOptions:
-    """
-    SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search.
-    """
-    solver_perference: SolverPerference = SolverPerference.STANDARD
-    dataloader_option: DataloaderOption = DataloaderOption.REPLICATED

@@ -33,7 +33,7 @@ class Solver:
                  solution_numbers: int = 1,
                  forward_only: bool = False,
                  memory_increasing_coefficient: float = 1.3,
-                 verbose=True):
+                 verbose=False):
         '''
         Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph.
         Argument:

@@ -17,7 +17,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVe
 from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
 from colossalai.device.device_mesh import DeviceMesh

-from .options import DataloaderOption, SolverOptions
+from ..options import DataloaderOption, SolverOptions

 __all__ = ['StrategiesConstructor']

@@ -101,7 +101,11 @@ class StrategiesConstructor:

             # get_attr node
             elif node.op == 'get_attr':
-                getattr_handler = GetattrHandler(node, self.device_mesh, strategies_vector)
+                getattr_handler = GetattrHandler(node,
+                                                 self.device_mesh,
+                                                 strategies_vector,
+                                                 shard_option=self.solver_options.shard_option,
+                                                 solver_perference=self.solver_options.solver_perference)
                 getattr_handler.register_strategy()

             # call_module node
@@ -109,7 +113,11 @@ class StrategiesConstructor:
                 target = node.target
                 submod = self.root_module.get_submodule(target)
                 submod_type = type(submod)
-                handler = operator_registry.get(submod_type)(node, self.device_mesh, strategies_vector)
+                handler = operator_registry.get(submod_type)(node,
+                                                             self.device_mesh,
+                                                             strategies_vector,
+                                                             shard_option=self.solver_options.shard_option,
+                                                             solver_perference=self.solver_options.solver_perference)
                 handler.register_strategy()
                 # attach metainfo_vector to node
                 if hasattr(handler, 'metainfo_vector'):
@@ -118,7 +126,11 @@ class StrategiesConstructor:
             # call_function node
             elif node.op == 'call_function':
                 target = node.target
-                handler = operator_registry.get(target)(node, self.device_mesh, strategies_vector)
+                handler = operator_registry.get(target)(node,
+                                                        self.device_mesh,
+                                                        strategies_vector,
+                                                        shard_option=self.solver_options.shard_option,
+                                                        solver_perference=self.solver_options.solver_perference)
                 handler.register_strategy()
                 # attach metainfo_vector to node
                 if hasattr(handler, 'metainfo_vector'):
@@ -127,7 +139,11 @@ class StrategiesConstructor:
             # call_method node
             elif node.op == 'call_method':
                 method = getattr(node.args[0]._meta_data.__class__, node.target)
-                handler = operator_registry.get(method)(node, self.device_mesh, strategies_vector)
+                handler = operator_registry.get(method)(node,
+                                                        self.device_mesh,
+                                                        strategies_vector,
+                                                        shard_option=self.solver_options.shard_option,
+                                                        solver_perference=self.solver_options.solver_perference)
                 handler.register_strategy()
                 # attach metainfo_vector to node
                 if hasattr(handler, 'metainfo_vector'):

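Since every handler now receives shard_option and solver_perference from the solver options, a handler can also be constructed with them directly; a hedged sketch in which `node` and `device_mesh` are placeholders and the StrategiesVector construction is an assumption:

    from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler
    from colossalai.auto_parallel.tensor_shard.options import ShardOption, SolverPerference
    from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector

    # `node` is a traced torch.fx Node and `device_mesh` an existing DeviceMesh (placeholders).
    strategies_vector = StrategiesVector(node)
    handler = LinearFunctionHandler(node,
                                    device_mesh,
                                    strategies_vector,
                                    shard_option=ShardOption.SHARD,
                                    solver_perference=SolverPerference.STANDARD)
    handler.register_strategy(compute_resharding_cost=False)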
@@ -4,13 +4,8 @@ import transformers
 from torch.fx import GraphModule

 from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
-from colossalai.auto_parallel.tensor_shard.solver import (
-    CostGraph,
-    GraphAnalyser,
-    Solver,
-    SolverOptions,
-    StrategiesConstructor,
-)
+from colossalai.auto_parallel.tensor_shard.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager

@@ -7,8 +7,9 @@ from torch.fx import GraphModule

 from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
 from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
+from colossalai.auto_parallel.tensor_shard.options import SolverOptions
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType, TrainCycleItem
-from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor
+from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.tracer.tracer import ColoTracer

@@ -5,7 +5,7 @@ import torch.multiprocessing as mp
 import torch.nn as nn

 from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler
-from colossalai.auto_parallel.tensor_shard.node_handler.option import ShardOption
+from colossalai.auto_parallel.tensor_shard.options import ShardOption
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
@@ -49,6 +49,15 @@ def check_shard_option(shard_option):
     strategies_vector = handler.register_strategy(compute_resharding_cost=False)
     strategy_name_list = [val.name for val in strategies_vector]

+    if shard_option == ShardOption.SHARD_LAST_AXIS:
+        # RR = RS x SR
+        assert 'RR = RS1 x S1R' in strategy_name_list
+
+        # RS = RR x RS
+        assert 'RS1 = RR x RS1' in strategy_name_list
+
+        return
+
     # SS = SR x RS
     assert 'S1S0 = S1R x RS0_0' in strategy_name_list
     assert 'S0S1 = S0R x RS1_1' in strategy_name_list
@@ -104,7 +113,8 @@ def check_shard_option(shard_option):

 @run_on_environment_flag(name='AUTO_PARALLEL')
 def test_shard_option():
-    for shard_option in [ShardOption.STANDARD, ShardOption.SHARD, ShardOption.FULL_SHARD]:
+    # for shard_option in [ShardOption.STANDARD, ShardOption.SHARD, ShardOption.FULL_SHARD, ShardOption.SHARD_LAST_AXIS]:
+    for shard_option in [ShardOption.SHARD_LAST_AXIS]:
         check_shard_option(shard_option)

@@ -6,7 +6,8 @@ from torch.fx import GraphModule

 from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
 from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
-from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor
+from colossalai.auto_parallel.tensor_shard.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor
 from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
 from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
 from colossalai.auto_parallel.tensor_shard.solver.solver import Solver

@@ -1,13 +1,8 @@
 import torch

+from colossalai.auto_parallel.tensor_shard.options import SolverOptions
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType
-from colossalai.auto_parallel.tensor_shard.solver import (
-    CostGraph,
-    GraphAnalyser,
-    Solver,
-    SolverOptions,
-    StrategiesConstructor,
-)
+from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.testing.pytest_wrapper import run_on_environment_flag

@@ -3,13 +3,8 @@ from torch.fx import GraphModule
 from torchvision.models import resnet50

 from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
-from colossalai.auto_parallel.tensor_shard.solver import (
-    CostGraph,
-    GraphAnalyser,
-    Solver,
-    SolverOptions,
-    StrategiesConstructor,
-)
+from colossalai.auto_parallel.tensor_shard.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager