[autoparallel] handled illegal strategy in node handler (#1743)

* [autoparallel] handled illegal strategy in node handler

* polish code
pull/1744/head
Frank Lee 2022-10-19 17:08:52 +08:00 committed by GitHub
parent 30874f1692
commit 88a79814fb
8 changed files with 89 additions and 82 deletions

View File

@@ -1,18 +1,20 @@
-import warnings
-import time
-import numpy as np
-import multiprocessing
-from torch.fx.node import Node
-from torch.fx.graph import Graph
-from .graph_analysis import GraphAnalyser
-from .cost_graph import CostGraph
-from .strategies_constructor import StrategiesConstructor
+import time
+import warnings
+from typing import Dict
+import numpy as np
+from torch.fx.graph import Graph
+from torch.fx.node import Node
+from .constants import INFINITY_COST
+from .cost_graph import CostGraph
+from .graph_analysis import GraphAnalyser
+from .strategies_constructor import StrategiesConstructor
 try:
     import pulp
-    from pulp import LpVariable, LpProblem, LpMinimize, lpSum, lpDot, LpStatus
+    from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum
 except:
     warnings.warn(f'please install the pulp')
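For context, the module whose imports are reorganized here appears to use PuLP to formulate sharding-strategy selection as an integer linear program. The snippet below is a minimal, self-contained sketch of that general idea (one binary variable per candidate strategy, exactly one strategy chosen per node, total cost minimized); the node names and costs are made up for illustration and this is not the repository's actual formulation.

import pulp
from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum

# Hypothetical per-node candidate strategy costs (node -> list of costs).
costs = {'conv1': [4.0, 2.5, 3.0], 'linear1': [1.0, 6.0]}

prob = LpProblem('strategy_selection', LpMinimize)

# One binary decision variable per (node, strategy) pair.
choices = {
    node: [LpVariable(f'{node}_s{i}', cat='Binary') for i in range(len(c))]
    for node, c in costs.items()
}

# Objective: minimize the total cost of the chosen strategies.
prob += lpSum(lpDot(choices[node], costs[node]) for node in costs)

# Constraint: exactly one strategy must be picked for every node.
for node, variables in choices.items():
    prob += lpSum(variables) == 1

prob.solve(pulp.PULP_CBC_CMD(msg=False))
print(LpStatus[prob.status])
for node, variables in choices.items():
    picked = [i for i, v in enumerate(variables) if v.value() == 1]
    print(node, '-> strategy', picked[0])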

View File

@@ -1,10 +1,16 @@
 from abc import ABC, abstractmethod
 from typing import Dict, List, Union
 import torch
 from torch.fx.node import Node
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, ShardingStrategy, StrategiesVector,
-                                                                      TrainCycleItem)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    OperationData,
+    ShardingStrategy,
+    StrategiesVector,
+    TrainCycleItem,
+)
+from colossalai.auto_parallel.tensor_shard.utils import check_sharding_spec_validity
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
@@ -98,6 +104,12 @@ class NodeHandler(ABC):
         self.strategies_vector.extend(post_processed_strategies)
+        # validating the correctness of the sharding strategy
+        for strategy in self.strategies_vector:
+            for op_data, sharding_spec in strategy.sharding_specs.items():
+                if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
+                    check_sharding_spec_validity(sharding_spec, op_data.data)
         return self.strategies_vector
     def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
@@ -116,8 +128,8 @@ class NodeHandler(ABC):
     def get_operation_data_mapping(self) -> Dict[str, OperationData]:
         """
         Returns the mapping between the logical operation data to its physical data.
-        A logical operation data is a data associated with an operation, which can be input and output. It is 
-        defined by the strategy generator, for example, a matrix multiplication operation has two operands "input" 
+        A logical operation data is a data associated with an operation, which can be input and output. It is
+        defined by the strategy generator, for example, a matrix multiplication operation has two operands "input"
         and "other" and one result "output". For a nn.Linear module, the physical operand for "input" is
         the module input, the physical operand for "other" is the module weight, and the physical result for "output"
         is the module output.
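To make the docstring above concrete, here is a hypothetical illustration of such a logical-to-physical mapping for an nn.Linear module; the OperationData used here is a simplified stand-in dataclass for illustration only, not the library's OperationData class, and the names are invented.

from dataclasses import dataclass

import torch
import torch.nn as nn


@dataclass
class OperationData:    # simplified stand-in, not the colossalai class
    name: str
    data: torch.Tensor


module = nn.Linear(16, 32)
x = torch.rand(4, 16)

# logical name -> physical data, as described in the docstring
mapping = {
    'input': OperationData(name='x', data=x),                     # the module input
    'other': OperationData(name='weight', data=module.weight),    # the module weight
    'output': OperationData(name='output', data=module(x)),       # the module output
}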

View File

@@ -3,7 +3,7 @@ import operator
 from functools import reduce
 from typing import List
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
 from colossalai.tensor.shape_consistency import CollectiveCommPattern
 from .strategy_generator import StrategyGenerator
@@ -31,8 +31,8 @@ class BatchNormStrategyGenerator(StrategyGenerator):
         For BatchNorm3d, the dim of input data should be 5([N, C, H, W, D]).
         '''
         input_op_data = self.op_data['input']
-        assert input_op_data.dim() in (3, 4,
-                                       5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
+        assert input_op_data.data.dim() in (
+            3, 4, 5), 'We suppose the dim of input fed into the batch norm op should be in range of [3, 5].'
     def update_compute_cost(self, strategy: ShardingStrategy):
         '''
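As a quick sanity check of the [3, 5] range asserted above (an illustration, not part of the patch), the inputs of BatchNorm1d, BatchNorm2d and BatchNorm3d carry 3, 4 and 5 dimensions respectively.

import torch
import torch.nn as nn

# dim() of typical inputs for each batch norm variant
print(nn.BatchNorm1d(8)(torch.rand(4, 8, 10)).dim())         # 3 -> [N, C, L]
print(nn.BatchNorm2d(8)(torch.rand(4, 8, 10, 10)).dim())     # 4 -> [N, C, H, W]
print(nn.BatchNorm3d(8)(torch.rand(4, 8, 4, 10, 10)).dim())  # 5 -> [N, C, D, H, W]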

View File

@@ -1,12 +1,17 @@
-from .broadcast import (BroadcastType, get_broadcast_shape, is_broadcastable, recover_sharding_spec_for_broadcast_shape)
+from .broadcast import BroadcastType, get_broadcast_shape, is_broadcastable, recover_sharding_spec_for_broadcast_shape
 from .factory import generate_resharding_costs, generate_sharding_spec
-from .misc import ignore_sharding_exception
-from .sharding import (enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding, generate_sharding_size,
-                       switch_partition_dim, update_partition_dim)
+from .misc import check_sharding_spec_validity, ignore_sharding_exception
+from .sharding import (
+    enumerate_all_possible_1d_sharding,
+    enumerate_all_possible_2d_sharding,
+    generate_sharding_size,
+    switch_partition_dim,
+    update_partition_dim,
+)
 __all__ = [
     'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
-    'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'switch_partition_dim',
-    'update_partition_dim', 'enumerate_all_possible_1d_sharding', 'enumerate_all_possible_2d_sharding',
-    'generate_sharding_size'
+    'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity',
+    'switch_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
+    'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
 ]

View File

@@ -1,7 +1,9 @@
 import functools
 import torch
 from colossalai.logging import get_dist_logger
-from colossalai.tensor.sharding_spec import ShardingSpecException
+from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
 __all__ = ['ignore_sharding_exception']
@@ -29,3 +31,37 @@ def ignore_sharding_exception(func):
             return None
     return wrapper
+def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tensor):
+    """
+    This function checks whether the ShardingSpec is valid for the physical tensor.
+    This check includes 2 items:
+    1. the sharding spec covers all dimensions of the physical tensor
+    2. the size of each sharded dimension is divisible by the number of devices it is sharded over
+    """
+    # make sure all dims are covered in sharding spec
+    sharding_len = len(sharding_spec.sharding_sequence)
+    tensor_num_dim = tensor.dim()
+    num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0]
+    num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1]
+    assert sharding_len == tensor_num_dim, \
+        f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).'
+    # make sure the sharding is valid for each dim
+    for i in range(tensor_num_dim):
+        dim_size = tensor.shape[i]
+        dim_spec = sharding_spec.sharding_sequence[i]
+        if str(dim_spec).startswith('S'):
+            devices_str = str(dim_spec).lstrip('S')
+            num_devices = 1
+            if '0' in devices_str:
+                num_devices *= num_devices_in_col
+            if '1' in devices_str:
+                num_devices *= num_devices_in_row
+            assert dim_size >= num_devices and dim_size % num_devices == 0, \
+                f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'
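The following standalone snippet illustrates the divisibility rule enforced by the second assertion above. It has no colossalai dependency; the 2 x 2 mesh shape and the example sharding sequences are hypothetical values chosen for illustration, not values from the patch.

mesh_shape = (2, 2)    # hypothetical 2 x 2 device mesh


def devices_for(dim_spec: str) -> int:
    """Number of devices a dimension is split over: 'S0' -> 2, 'S1' -> 2, 'S01' -> 4, 'R' -> 1."""
    if not dim_spec.startswith('S'):
        return 1
    num_devices = 1
    if '0' in dim_spec[1:]:
        num_devices *= mesh_shape[0]
    if '1' in dim_spec[1:]:
        num_devices *= mesh_shape[1]
    return num_devices


# A [8, 6] tensor sharded as ['S01', 'R'] is valid (8 % 4 == 0); ['R', 'S01'] is not (6 % 4 != 0).
for shape, sequence in (((8, 6), ['S01', 'R']), ((8, 6), ['R', 'S01'])):
    valid = all(size >= devices_for(spec) and size % devices_for(spec) == 0
                for size, spec in zip(shape, sequence))
    print(shape, sequence, 'valid' if valid else 'invalid')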

View File

@@ -1,37 +0,0 @@
-import torch
-from colossalai.tensor.sharding_spec import ShardingSpec
-def is_sharding_spec_valid(sharding_spec: ShardingSpec, tensor: torch.Tensor):
-    """
-    This function checks whether the ShardingSpec is valid for the physical tensor.
-    This check includes 2 items:
-    1. the sharding spec covers all dimensions of the physical tensor
-    2. the sharding spec for each dimension is divisible by the number of devices.
-    #
-    """
-    # make sure all dims are covered in sharding spec
-    sharding_len = len(sharding_spec.sharding_sequence)
-    tensor_num_dim = tensor.dim()
-    num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0]
-    num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1]
-    assert sharding_len == tensor_num_dim, \
-        f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).'
-    # make sure the sharding is valid for each dim
-    for i in range(tensor_num_dim):
-        dim_size = tensor.shape[i]
-        dim_spec = sharding_spec.sharding_sequence[i]
-        if str(dim_spec).startswith('S'):
-            devices_str = str(dim_spec).lstrip('S')
-            num_devices = 1
-            if '0' in devices_str:
-                num_devices *= num_devices_in_col
-            if '1' in devices_str:
-                num_devices *= num_devices_in_row
-            assert dim_size >= num_devices and dim_size % num_devices == 0, \
-                f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'

View File

@@ -1,11 +1,10 @@
 import torch
 import torch.nn as nn
-from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import (ConvFunctionHandler, ConvModuleHandler)
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
+from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler, ConvModuleHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.fx.tracer.meta_patch.patched_module import linear
 def test_conv_module_handler():

View File

@@ -1,13 +1,15 @@
 import torch
 import torch.nn as nn
-from colossalai.auto_parallel.tensor_shard.node_handler.dot_handler import (LinearFunctionHandler, LinearModuleHandler)
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, ShardingStrategy,
-                                                                      StrategiesVector)
+from colossalai.auto_parallel.tensor_shard.node_handler.dot_handler import LinearFunctionHandler, LinearModuleHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    OperationData,
+    OperationDataType,
+    ShardingStrategy,
+    StrategiesVector,
+)
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
-from tests.test_auto_parallel.test_tensor_shard.test_node_handler.common import \
-    is_sharding_spec_valid
 def test_linear_module_handler():
@@ -92,12 +94,6 @@ def test_linear_module_handler():
         bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
         output_sharding_spec = strategy.get_sharding_spec_by_name('_0')
-        # make sure the sharding spec is valid
-        is_sharding_spec_valid(input_sharding_spec, torch.rand(2, 2, 4, 16))
-        is_sharding_spec_valid(weight_sharding_spec, model.get_parameter('0.weight'))
-        is_sharding_spec_valid(bias_sharding_spec, model.get_parameter('0.bias'))
-        is_sharding_spec_valid(output_sharding_spec, torch.rand([2, 2, 4, 32]))
         # make sure the sharding matches across different operation data
         assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
         assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
@@ -182,12 +178,6 @@ def test_linear_function_handler():
         bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
         output_sharding_spec = strategy.get_sharding_spec_by_name('linear')
-        # make sure the sharding spec is valid
-        is_sharding_spec_valid(input_sharding_spec, torch.rand(2, 2, 4, 16))
-        is_sharding_spec_valid(weight_sharding_spec, model.get_parameter('weight'))
-        is_sharding_spec_valid(bias_sharding_spec, model.get_parameter('bias'))
-        is_sharding_spec_valid(output_sharding_spec, torch.rand([2, 2, 4, 32]))
         # make sure the sharding matches across different operation data
         assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
         assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]