[autoparallel] fix spelling error (#2270)

pull/2277/head
YuliangLiu0306 2023-01-03 16:13:00 +08:00 committed by GitHub
parent af32022f74
commit fb87322773
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 46 additions and 46 deletions

View File

@@ -6,9 +6,9 @@ from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
-from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
-    ignore_sharding_exception
-from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
+from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
+from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP
from .operator_handler import OperatorHandler

@@ -82,13 +82,13 @@ class MatVecStrategyGenerator(StrategyGenerator):
class MatMulStrategyGenerator(StrategyGenerator):
"""
MatMulStrategyGenerator is used to generate the sharding strategies when the second tensor is
a 2D tensor. This is used for nn.Linear, F.linear, torch.matmul and torch.addmm.
A matmul can be formulated as [n, p] x [p, q] = [n, q]
Args:
is_linear (bool): whether this generator is used for nn.Linear and F.linear.
This will incur extra transformation of the dim partitioning as the weight is transposed.
"""
@@ -255,7 +255,7 @@ class BatchedMatMulStrategyGenerator(StrategyGenerator):
"""
Generate sharding strategies for the batched matrix multiplication.
A batched matrix multiplication can be viewed as
[b, i, k] x [b, k, j] -> [b, i, j]
"""
@@ -431,7 +431,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
+sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

@@ -451,7 +451,7 @@ class DotHandler(OperatorHandler):
# create and register strategy
sharding_strategies = ShardingStrategy(name,
-output_sharding_spec=sharding_spec_for_ouput,
+output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=toatl_memory_cost,

@@ -473,7 +473,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0]}
-sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

@@ -491,7 +491,7 @@ class DotHandler(OperatorHandler):
communication_cost_grad_backward = self.device_mesh.all_reduce_cost(weight_memory_cost, mesh_dim_0)
communication_cost = communication_cost_activation_forward + communication_cost_grad_backward
sharding_strategies = ShardingStrategy(name,
-output_sharding_spec=sharding_spec_for_ouput,
+output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=toatl_memory_cost,

@@ -510,7 +510,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_1]}
-sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
+sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

@@ -529,7 +529,7 @@ class DotHandler(OperatorHandler):
communication_cost = communication_cost_activation_backward + communication_cost_activation_forward
sharding_strategies = ShardingStrategy(name,
-output_sharding_spec=sharding_spec_for_ouput,
+output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=toatl_memory_cost,

@@ -548,7 +548,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
-sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

@@ -564,7 +564,7 @@ class DotHandler(OperatorHandler):
# compute the communication cost of this strategy
communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim)
sharding_strategies = ShardingStrategy(name,
-output_sharding_spec=sharding_spec_for_ouput,
+output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=toatl_memory_cost,

@@ -583,7 +583,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim]}
-sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

@@ -600,7 +600,7 @@ class DotHandler(OperatorHandler):
communication_cost_activation_backward = self.device_mesh.all_reduce_cost(input_grad_memory_cost, mesh_dim)
communication_cost = communication_cost_activation_backward
sharding_strategies = ShardingStrategy(name,
-output_sharding_spec=sharding_spec_for_ouput,
+output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=toatl_memory_cost,

@@ -619,7 +619,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
-sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

@@ -636,7 +636,7 @@ class DotHandler(OperatorHandler):
communication_cost_weight_backward = self.device_mesh.flatten_device_mesh.all_reduce_cost(weight_memory_cost, 0)
communication_cost = communication_cost_weight_backward
sharding_strategies = ShardingStrategy(name,
-output_sharding_spec=sharding_spec_for_ouput,
+output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=toatl_memory_cost,

@@ -655,7 +655,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
-sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

@@ -673,7 +673,7 @@ class DotHandler(OperatorHandler):
activation_memory_cost, 0)
communication_cost = communication_cost_forward_activation
sharding_strategies = ShardingStrategy(name,
-output_sharding_spec=sharding_spec_for_ouput,
+output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=toatl_memory_cost,

@@ -692,7 +692,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
-sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

@@ -709,7 +709,7 @@ class DotHandler(OperatorHandler):
input_grad_memory_cost, 0)
communication_cost = communication_cost_activation_backward
sharding_strategies = ShardingStrategy(name,
-output_sharding_spec=sharding_spec_for_ouput,
+output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=toatl_memory_cost,

View File

@@ -5,14 +5,14 @@ from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
from .conv_handler import ConvFunctionHandler, ConvModuleHandler
from .embedding_handler import EmbeddingFunctionHandler, EmbeddingModuleHandler
from .experimental import PermuteHandler, ViewHandler
-from .getatrr_handler import GetattrHandler
+from .getattr_handler import GetattrHandler
from .getitem_handler import GetItemHandler
from .layer_norm_handler import LayerNormModuleHandler
from .linear_handler import LinearFunctionHandler, LinearModuleHandler
from .matmul_handler import MatMulHandler
from .normal_pooling_handler import NormPoolingHandler
-from .output_handler import OuputHandler
-from .placeholder_handler import PlacehodlerHandler
+from .output_handler import OutputHandler
+from .placeholder_handler import PlaceholderHandler
from .registry import operator_registry
from .reshape_handler import ReshapeHandler
from .softmax_handler import SoftmaxHandler

@@ -24,7 +24,7 @@ from .where_handler import WhereHandler
__all__ = [
'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
-'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
+'UnaryElementwiseHandler', 'ReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler'

View File

@@ -8,12 +8,12 @@ from ..sharding_strategy import OperationData, OperationDataType, StrategiesVect
from .node_handler import NodeHandler
from .strategy import OutputGenerator, StrategyGenerator
-__all__ = ['OuputHandler']
+__all__ = ['OutputHandler']
-class OuputHandler(NodeHandler):
+class OutputHandler(NodeHandler):
"""
-A OuputHandler which deals with the sharding strategies for Output Node.
+A OutputHandler which deals with the sharding strategies for Output Node.
"""
def __init__(self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,

View File

@@ -8,12 +8,12 @@ from ..sharding_strategy import OperationData, OperationDataType, StrategiesVect
from .node_handler import NodeHandler
from .strategy import PlaceholderGenerator, StrategyGenerator
-__all__ = ['PlacehodlerHandler']
+__all__ = ['PlaceholderHandler']
-class PlacehodlerHandler(NodeHandler):
+class PlaceholderHandler(NodeHandler):
"""
-A PlacehodlerHandler which deals with the sharding strategies for Placeholder Node.
+A PlaceholderHandler which deals with the sharding strategies for Placeholder Node.
"""
def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,

View File

@@ -9,8 +9,8 @@ from torch.fx import Graph, Node
from colossalai.auto_parallel.tensor_shard.node_handler import (
GetattrHandler,
-OuputHandler,
-PlacehodlerHandler,
+OutputHandler,
+PlaceholderHandler,
operator_registry,
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector

@@ -93,7 +93,7 @@ class StrategiesConstructor:
else:
assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported'
placeholder_option = 'replicated'
-placeholder_handler = PlacehodlerHandler(node,
+placeholder_handler = PlaceholderHandler(node,
self.device_mesh,
strategies_vector,
placeholder_option=placeholder_option)

@@ -140,7 +140,7 @@ class StrategiesConstructor:
else:
assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported'
output_option = 'replicated'
-output_handler = OuputHandler(node, self.device_mesh, strategies_vector, output_option=output_option)
+output_handler = OutputHandler(node, self.device_mesh, strategies_vector, output_option=output_option)
output_handler.register_strategy()
self.remove_duplicated_strategy(strategies_vector)

View File

@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
-from colossalai.auto_parallel.tensor_shard.node_handler.getatrr_handler import GetattrHandler
+from colossalai.auto_parallel.tensor_shard.node_handler.getattr_handler import GetattrHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer

View File

@@ -7,7 +7,7 @@ import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import GetItemHandler
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
-from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlacehodlerHandler
+from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import ReshapeHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh

@@ -145,7 +145,7 @@ def test_getitem_from_tuple_handler():
split_strategies_vector = StrategiesVector(split_node)
# build handler
-input_handler = PlacehodlerHandler(
+input_handler = PlaceholderHandler(
node=input_node,
device_mesh=device_mesh,
strategies_vector=input_strategies_vector,

View File

@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
-from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OuputHandler
+from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OutputHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer

@@ -39,10 +39,10 @@ def test_output_handler(output_option):
output_strategies_vector = StrategiesVector(output_node)
# build handler
-otuput_handler = OuputHandler(node=output_node,
+otuput_handler = OutputHandler(node=output_node,
device_mesh=device_mesh,
strategies_vector=output_strategies_vector,
output_option=output_option)
otuput_handler.register_strategy(compute_resharding_cost=False)
# check operation data mapping

View File

@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
-from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlacehodlerHandler
+from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer

@@ -36,7 +36,7 @@ def test_placeholder_handler(placeholder_option):
placeholder_node = list(graph.nodes)[0]
placeholder_strategies_vector = StrategiesVector(placeholder_node)
# build handler
-placeholder_handler = PlacehodlerHandler(node=placeholder_node,
+placeholder_handler = PlaceholderHandler(node=placeholder_node,
device_mesh=device_mesh,
strategies_vector=placeholder_strategies_vector,
placeholder_option=placeholder_option)