mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] fix spelling error (#2270)
parent af32022f74
commit fb87322773

@@ -6,9 +6,9 @@ from typing import List
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
-    ignore_sharding_exception
-from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
+from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
+from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
 
 from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP
 from .operator_handler import OperatorHandler

@@ -82,13 +82,13 @@ class MatVecStrategyGenerator(StrategyGenerator):
 class MatMulStrategyGenerator(StrategyGenerator):
     """
-    MatMulStrategyGenerator is used to generate the sharding strategies when the second tensor is
+    MatMulStrategyGenerator is used to generate the sharding strategies when the second tensor is
     a 2D tensor. This is used for nn.Linear, F.linear, torch.matmul and torch.addmm.
 
     A matmul can be formulated as [n, p] x [p, q] = [n, q]
 
     Args:
-        is_linear (bool): whether this generator is used for nn.Linear and F.linear.
+        is_linear (bool): whether this generator is used for nn.Linear and F.linear.
             This will incur extra transformation of the dim partitioning as the weight is transposed.
     """

@@ -255,7 +255,7 @@ class BatchedMatMulStrategyGenerator(StrategyGenerator):
     """
     Generate sharding strategies for the batched matrix multiplication.
 
-    A batched matrix multiplication can be viewed as
+    A batched matrix multiplication can be viewed as
     [b, i, k] x [b, k, j] -> [b, i, j]
     """

@@ -431,7 +431,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
 
         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
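
A note for the DotHandler hunks that follow: a dim_partition_dict maps a tensor dimension to the device-mesh dimensions it is sharded along, and any dimension not listed stays replicated. A minimal sketch of the convention (variable names are illustrative, not the handler's internals):

    # shard output dim 0 along mesh dim 0 and output dim 1 along mesh dim 1;
    # an empty dict {} means the tensor is fully replicated
    mesh_dim_0, mesh_dim_1 = 0, 1
    dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}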

@@ -451,7 +451,7 @@ class DotHandler(OperatorHandler):
 
         # create and register strategy
         sharding_strategies = ShardingStrategy(name,
-                                               output_sharding_spec=sharding_spec_for_ouput,
+                                               output_sharding_spec=sharding_spec_for_output,
                                                compute_cost=compute_cost,
                                                communication_cost=communication_cost,
                                                memory_cost=toatl_memory_cost,

@@ -473,7 +473,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {0: [mesh_dim_0]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

@@ -491,7 +491,7 @@ class DotHandler(OperatorHandler):
         communication_cost_grad_backward = self.device_mesh.all_reduce_cost(weight_memory_cost, mesh_dim_0)
         communication_cost = communication_cost_activation_forward + communication_cost_grad_backward
         sharding_strategies = ShardingStrategy(name,
-                                               output_sharding_spec=sharding_spec_for_ouput,
+                                               output_sharding_spec=sharding_spec_for_output,
                                                compute_cost=compute_cost,
                                                communication_cost=communication_cost,
                                                memory_cost=toatl_memory_cost,

@@ -510,7 +510,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {1: [mesh_dim_1]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
 
         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

@@ -529,7 +529,7 @@ class DotHandler(OperatorHandler):
         communication_cost = communication_cost_activation_backward + communication_cost_activation_forward
 
         sharding_strategies = ShardingStrategy(name,
-                                               output_sharding_spec=sharding_spec_for_ouput,
+                                               output_sharding_spec=sharding_spec_for_output,
                                                compute_cost=compute_cost,
                                                communication_cost=communication_cost,
                                                memory_cost=toatl_memory_cost,

@@ -548,7 +548,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

@@ -564,7 +564,7 @@ class DotHandler(OperatorHandler):
         # compute the communication cost of this strategy
         communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim)
         sharding_strategies = ShardingStrategy(name,
-                                               output_sharding_spec=sharding_spec_for_ouput,
+                                               output_sharding_spec=sharding_spec_for_output,
                                                compute_cost=compute_cost,
                                                communication_cost=communication_cost,
                                                memory_cost=toatl_memory_cost,

@@ -583,7 +583,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {1: [mesh_dim]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

@@ -600,7 +600,7 @@ class DotHandler(OperatorHandler):
         communication_cost_activation_backward = self.device_mesh.all_reduce_cost(input_grad_memory_cost, mesh_dim)
         communication_cost = communication_cost_activation_backward
         sharding_strategies = ShardingStrategy(name,
-                                               output_sharding_spec=sharding_spec_for_ouput,
+                                               output_sharding_spec=sharding_spec_for_output,
                                                compute_cost=compute_cost,
                                                communication_cost=communication_cost,
                                                memory_cost=toatl_memory_cost,

@@ -619,7 +619,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

@@ -636,7 +636,7 @@ class DotHandler(OperatorHandler):
         communication_cost_weight_backward = self.device_mesh.flatten_device_mesh.all_reduce_cost(weight_memory_cost, 0)
         communication_cost = communication_cost_weight_backward
         sharding_strategies = ShardingStrategy(name,
-                                               output_sharding_spec=sharding_spec_for_ouput,
+                                               output_sharding_spec=sharding_spec_for_output,
                                                compute_cost=compute_cost,
                                                communication_cost=communication_cost,
                                                memory_cost=toatl_memory_cost,

@@ -655,7 +655,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

@@ -673,7 +673,7 @@ class DotHandler(OperatorHandler):
             activation_memory_cost, 0)
         communication_cost = communication_cost_forward_activation
         sharding_strategies = ShardingStrategy(name,
-                                               output_sharding_spec=sharding_spec_for_ouput,
+                                               output_sharding_spec=sharding_spec_for_output,
                                                compute_cost=compute_cost,
                                                communication_cost=communication_cost,
                                                memory_cost=toatl_memory_cost,

@@ -692,7 +692,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

@@ -709,7 +709,7 @@ class DotHandler(OperatorHandler):
             input_grad_memory_cost, 0)
         communication_cost = communication_cost_activation_backward
         sharding_strategies = ShardingStrategy(name,
-                                               output_sharding_spec=sharding_spec_for_ouput,
+                                               output_sharding_spec=sharding_spec_for_output,
                                                compute_cost=compute_cost,
                                                communication_cost=communication_cost,
                                                memory_cost=toatl_memory_cost,
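
All of the DotHandler hunks above touch the same registration pattern: build a dim partition dict per operand, turn each into a sharding spec, estimate compute/communication/memory costs, and bundle the result into a ShardingStrategy (the trailing constructor arguments are cut off by the hunk boundaries, and the pre-existing "toatl_memory_cost" spelling in the context lines is not touched by this commit). The sketch below uses stand-in dataclasses rather than the ColossalAI classes so the flow can be read in one self-contained place:

    from dataclasses import dataclass
    from typing import Dict, List

    @dataclass
    class ShardingSpecSketch:
        # tensor dim -> device-mesh dims it is sharded along
        dim_partition_dict: Dict[int, List[int]]

    @dataclass
    class ShardingStrategySketch:
        name: str
        output_sharding_spec: ShardingSpecSketch
        compute_cost: float = 0.0
        communication_cost: float = 0.0
        memory_cost: float = 0.0

    def register_sketch(mesh_dim_0: int, mesh_dim_1: int) -> ShardingStrategySketch:
        # mirrors the visible pattern: output sharded on both mesh dims,
        # costs filled in by the (omitted) cost model
        output_spec = ShardingSpecSketch({0: [mesh_dim_0], 1: [mesh_dim_1]})
        return ShardingStrategySketch(name='split output on both mesh dims',
                                      output_sharding_spec=output_spec)

    print(register_sketch(0, 1).name)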

@@ -5,14 +5,14 @@ from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
 from .conv_handler import ConvFunctionHandler, ConvModuleHandler
 from .embedding_handler import EmbeddingFunctionHandler, EmbeddingModuleHandler
 from .experimental import PermuteHandler, ViewHandler
-from .getatrr_handler import GetattrHandler
+from .getattr_handler import GetattrHandler
 from .getitem_handler import GetItemHandler
 from .layer_norm_handler import LayerNormModuleHandler
 from .linear_handler import LinearFunctionHandler, LinearModuleHandler
 from .matmul_handler import MatMulHandler
 from .normal_pooling_handler import NormPoolingHandler
-from .output_handler import OuputHandler
-from .placeholder_handler import PlacehodlerHandler
+from .output_handler import OutputHandler
+from .placeholder_handler import PlaceholderHandler
 from .registry import operator_registry
 from .reshape_handler import ReshapeHandler
 from .softmax_handler import SoftmaxHandler

@@ -24,7 +24,7 @@ from .where_handler import WhereHandler
 __all__ = [
     'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
     'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
-    'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
+    'UnaryElementwiseHandler', 'ReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
     'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
     'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
     'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler'

@@ -8,12 +8,12 @@ from ..sharding_strategy import OperationData, OperationDataType, StrategiesVect
 from .node_handler import NodeHandler
 from .strategy import OutputGenerator, StrategyGenerator
 
-__all__ = ['OuputHandler']
+__all__ = ['OutputHandler']
 
 
-class OuputHandler(NodeHandler):
+class OutputHandler(NodeHandler):
     """
-    A OuputHandler which deals with the sharding strategies for Output Node.
+    A OutputHandler which deals with the sharding strategies for Output Node.
     """
 
     def __init__(self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,

@@ -8,12 +8,12 @@ from ..sharding_strategy import OperationData, OperationDataType, StrategiesVect
 from .node_handler import NodeHandler
 from .strategy import PlaceholderGenerator, StrategyGenerator
 
-__all__ = ['PlacehodlerHandler']
+__all__ = ['PlaceholderHandler']
 
 
-class PlacehodlerHandler(NodeHandler):
+class PlaceholderHandler(NodeHandler):
     """
-    A PlacehodlerHandler which deals with the sharding strategies for Placeholder Node.
+    A PlaceholderHandler which deals with the sharding strategies for Placeholder Node.
     """
 
     def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,

@@ -9,8 +9,8 @@ from torch.fx import Graph, Node
 
 from colossalai.auto_parallel.tensor_shard.node_handler import (
     GetattrHandler,
-    OuputHandler,
-    PlacehodlerHandler,
+    OutputHandler,
+    PlaceholderHandler,
     operator_registry,
 )
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector

@@ -93,7 +93,7 @@ class StrategiesConstructor:
                 else:
                     assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported'
                     placeholder_option = 'replicated'
-                placeholder_handler = PlacehodlerHandler(node,
+                placeholder_handler = PlaceholderHandler(node,
                                                          self.device_mesh,
                                                          strategies_vector,
                                                          placeholder_option=placeholder_option)

@@ -140,7 +140,7 @@ class StrategiesConstructor:
                 else:
                     assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported'
                     output_option = 'replicated'
-                output_handler = OuputHandler(node, self.device_mesh, strategies_vector, output_option=output_option)
+                output_handler = OutputHandler(node, self.device_mesh, strategies_vector, output_option=output_option)
                 output_handler.register_strategy()
 
                 self.remove_duplicated_strategy(strategies_vector)
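
For context on the two StrategiesConstructor hunks above: the constructor picks a dataloader option for placeholder and output nodes and then hands the node to the matching handler. The sketch below reconstructs only that dispatch; the DISTRIBUTED branch and the enum values are assumptions inferred from the visible assert, not code taken from this commit:

    from enum import Enum

    class DataloaderOption(Enum):
        # stand-in for the enum referenced in the asserts above (values assumed)
        REPLICATED = 0
        DISTRIBUTED = 1

    def pick_placeholder_option(dataloader_option: DataloaderOption) -> str:
        # assumed DISTRIBUTED branch; the fallback path mirrors the context lines above
        if dataloader_option == DataloaderOption.DISTRIBUTED:
            return 'distributed'
        assert dataloader_option == DataloaderOption.REPLICATED, \
            f'placeholder_option {dataloader_option} is not supported'
        return 'replicated'

    print(pick_placeholder_option(DataloaderOption.REPLICATED))  # -> replicated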

@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 
-from colossalai.auto_parallel.tensor_shard.node_handler.getatrr_handler import GetattrHandler
+from colossalai.auto_parallel.tensor_shard.node_handler.getattr_handler import GetattrHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer

@@ -7,7 +7,7 @@ import torch.nn as nn
 
 from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import GetItemHandler
 from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
-from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlacehodlerHandler
+from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
 from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import ReshapeHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh

@@ -145,7 +145,7 @@ def test_getitem_from_tuple_handler():
     split_strategies_vector = StrategiesVector(split_node)
 
     # build handler
-    input_handler = PlacehodlerHandler(
+    input_handler = PlaceholderHandler(
         node=input_node,
         device_mesh=device_mesh,
         strategies_vector=input_strategies_vector,

@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 
-from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OuputHandler
+from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OutputHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer

@@ -39,10 +39,10 @@ def test_output_handler(output_option):
     output_strategies_vector = StrategiesVector(output_node)
 
     # build handler
-    otuput_handler = OuputHandler(node=output_node,
-                                  device_mesh=device_mesh,
-                                  strategies_vector=output_strategies_vector,
-                                  output_option=output_option)
+    otuput_handler = OutputHandler(node=output_node,
+                                   device_mesh=device_mesh,
+                                   strategies_vector=output_strategies_vector,
+                                   output_option=output_option)
 
     otuput_handler.register_strategy(compute_resharding_cost=False)
     # check operation data mapping

@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 
-from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlacehodlerHandler
+from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer

@@ -36,7 +36,7 @@ def test_placeholder_handler(placeholder_option):
     placeholder_node = list(graph.nodes)[0]
     placeholder_strategies_vector = StrategiesVector(placeholder_node)
     # build handler
-    placeholder_handler = PlacehodlerHandler(node=placeholder_node,
+    placeholder_handler = PlaceholderHandler(node=placeholder_node,
                                              device_mesh=device_mesh,
                                              strategies_vector=placeholder_strategies_vector,
                                              placeholder_option=placeholder_option)