mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] autoparallel unit test (#1752)
parent
a4ce180e85
commit
cdb7d5e7d2
|
@ -1,14 +1,15 @@
|
||||||
from .operator_handler import OperatorHandler
|
|
||||||
from .dot_handler import DotHandler
|
|
||||||
from .conv_handler import ConvHandler
|
|
||||||
from .batch_norm_handler import BatchNormHandler
|
from .batch_norm_handler import BatchNormHandler
|
||||||
from .reshape_handler import ReshapeHandler
|
|
||||||
from .bcast_op_handler import BcastOpHandler
|
from .bcast_op_handler import BcastOpHandler
|
||||||
|
from .conv_handler import ConvHandler
|
||||||
|
from .dot_handler import DotHandler
|
||||||
from .embedding_handler import EmbeddingHandler
|
from .embedding_handler import EmbeddingHandler
|
||||||
|
from .layer_norm_handler import LayerNormHandler
|
||||||
|
from .operator_handler import OperatorHandler
|
||||||
|
from .reshape_handler import ReshapeHandler
|
||||||
from .unary_elementwise_handler import UnaryElementwiseHandler
|
from .unary_elementwise_handler import UnaryElementwiseHandler
|
||||||
from .where_handler import WhereHandler
|
from .where_handler import WhereHandler
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler',
|
'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler',
|
||||||
'UnaryElementwiseHandler', 'EmbeddingHandler', 'WhereHandler'
|
'UnaryElementwiseHandler', 'EmbeddingHandler', 'WhereHandler', 'LayerNormHandler'
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,17 +1,18 @@
|
||||||
import torch
|
from copy import deepcopy
|
||||||
from torch.fx import GraphModule
|
|
||||||
import torch.nn as nn
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from colossalai.fx.tracer.tracer import ColoTracer
|
import pytest
|
||||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
import torch
|
||||||
from colossalai.device.device_mesh import DeviceMesh
|
import torch.nn as nn
|
||||||
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
|
from torch.fx import GraphModule
|
||||||
|
|
||||||
|
from colossalai.auto_parallel.tensor_shard.deprecated import Solver
|
||||||
from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph
|
from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph
|
||||||
from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
|
from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
|
||||||
from copy import deepcopy
|
|
||||||
from colossalai.auto_parallel.tensor_shard.deprecated import Solver
|
|
||||||
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
|
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
|
||||||
|
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
|
||||||
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
from colossalai.fx.tracer.tracer import ColoTracer
|
||||||
|
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||||
|
|
||||||
|
|
||||||
|
@ -60,7 +61,7 @@ def test_solver():
|
||||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||||
|
|
||||||
solver_options = SolverOptions(fast=True)
|
solver_options = SolverOptions(fast=True)
|
||||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
|
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||||
strategies_constructor.build_strategies_and_cost()
|
strategies_constructor.build_strategies_and_cost()
|
||||||
|
|
||||||
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
|
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
|
||||||
|
|
Loading…
Reference in New Issue