mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] handled illegal strategy in node handler (#1743)
* [autoparallel] handled illegal strategy in node handler
* polish code
parent 30874f1692
commit 88a79814fb
@@ -1,18 +1,20 @@
-import warnings
-import time
-import numpy as np
 import multiprocessing
-from torch.fx.node import Node
-from torch.fx.graph import Graph
-from .graph_analysis import GraphAnalyser
-from .cost_graph import CostGraph
-from .strategies_constructor import StrategiesConstructor
+import time
+import warnings
 from typing import Dict
 
+import numpy as np
+from torch.fx.graph import Graph
+from torch.fx.node import Node
+
 from .constants import INFINITY_COST
+from .cost_graph import CostGraph
+from .graph_analysis import GraphAnalyser
+from .strategies_constructor import StrategiesConstructor
 
 try:
     import pulp
-    from pulp import LpVariable, LpProblem, LpMinimize, lpSum, lpDot, LpStatus
+    from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum
 except:
     warnings.warn(f'please install the pulp')
@@ -1,10 +1,16 @@
 from abc import ABC, abstractmethod
 from typing import Dict, List, Union
 
+import torch
 from torch.fx.node import Node
 
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, ShardingStrategy, StrategiesVector,
-                                                                     TrainCycleItem)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    OperationData,
+    ShardingStrategy,
+    StrategiesVector,
+    TrainCycleItem,
+)
+from colossalai.auto_parallel.tensor_shard.utils import check_sharding_spec_validity
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 
@@ -98,6 +104,12 @@ class NodeHandler(ABC):
 
         self.strategies_vector.extend(post_processed_strategies)
 
+        # validating the correctness of the sharding strategy
+        for strategy in self.strategies_vector:
+            for op_data, sharding_spec in strategy.sharding_specs.items():
+                if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
+                    check_sharding_spec_validity(sharding_spec, op_data.data)
+
         return self.strategies_vector
 
     def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
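This loop is the core of the commit: before register_strategy returns, every sharding spec of every generated strategy is validated against the physical tensor it shards, so an illegal strategy trips an assertion immediately instead of propagating into the solver. Below is a minimal standalone sketch of that validation pass; FakeOpData and fake_validity_check are hypothetical stand-ins for illustration, not ColossalAI classes.

import torch


# FakeOpData and fake_validity_check are illustrative stand-ins, not ColossalAI APIs.
class FakeOpData:

    def __init__(self, name, data):
        self.name = name
        self.data = data


def fake_validity_check(sharding_sequence, tensor):
    # placeholder for check_sharding_spec_validity: the spec must cover every tensor dim
    assert len(sharding_sequence) == tensor.dim()


# each strategy maps operation data to a sharding sequence
fake_strategies = [{
    FakeOpData('input', torch.rand(4, 16)): ['S0', 'R'],
    FakeOpData('constant', None): ['R'],
}]

# same shape as the loop above: skip non-tensor operation data, validate the rest
for strategy in fake_strategies:
    for op_data, sharding_sequence in strategy.items():
        if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
            fake_validity_check(sharding_sequence, op_data.data)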
@@ -3,7 +3,7 @@ import operator
 from functools import reduce
 from typing import List
 
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
 from colossalai.tensor.shape_consistency import CollectiveCommPattern
 
 from .strategy_generator import StrategyGenerator
@@ -31,8 +31,8 @@ class BatchNormStrategyGenerator(StrategyGenerator):
         For BatchNorm3d, the dim of input data should be 5([N, C, H, W, D]).
         '''
         input_op_data = self.op_data['input']
-        assert input_op_data.dim() in (3, 4,
-                                       5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
+        assert input_op_data.data.dim() in (
+            3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
 
     def update_compute_cost(self, strategy: ShardingStrategy):
         '''
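The assert fix above matters because self.op_data['input'] is an OperationData wrapper rather than a tensor: the tensor is exposed through its .data attribute (the same attribute the new NodeHandler check inspects), so the rank must be read via .data.dim(). A hedged illustration with a stand-in dataclass (FakeOperationData is not the real class):

import torch
from dataclasses import dataclass


@dataclass
class FakeOperationData:
    # stand-in for the wrapper: the tensor lives in .data, the wrapper itself has no .dim()
    name: str
    data: torch.Tensor


input_op_data = FakeOperationData(name='input', data=torch.rand(8, 3, 32, 32))
assert input_op_data.data.dim() in (3, 4, 5)    # 4-dim input for BatchNorm2d: passes
# input_op_data.dim()                           # would raise AttributeError on the wrapper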
@@ -1,12 +1,17 @@
-from .broadcast import (BroadcastType, get_broadcast_shape, is_broadcastable, recover_sharding_spec_for_broadcast_shape)
+from .broadcast import BroadcastType, get_broadcast_shape, is_broadcastable, recover_sharding_spec_for_broadcast_shape
 from .factory import generate_resharding_costs, generate_sharding_spec
-from .misc import ignore_sharding_exception
-from .sharding import (enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding, generate_sharding_size,
-                       switch_partition_dim, update_partition_dim)
+from .misc import check_sharding_spec_validity, ignore_sharding_exception
+from .sharding import (
+    enumerate_all_possible_1d_sharding,
+    enumerate_all_possible_2d_sharding,
+    generate_sharding_size,
+    switch_partition_dim,
+    update_partition_dim,
+)
 
 __all__ = [
     'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
-    'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'switch_partition_dim',
-    'update_partition_dim', 'enumerate_all_possible_1d_sharding', 'enumerate_all_possible_2d_sharding',
-    'generate_sharding_size'
+    'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity'
+    'switch_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
+    'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
 ]
@@ -1,7 +1,9 @@
 import functools
 
+import torch
+
 from colossalai.logging import get_dist_logger
-from colossalai.tensor.sharding_spec import ShardingSpecException
+from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
 
 __all__ = ['ignore_sharding_exception']
 
@@ -29,3 +31,37 @@ def ignore_sharding_exception(func):
             return None
 
     return wrapper
+
+
+def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tensor):
+    """
+    This function checks whether the ShardingSpec is valid for the physical tensor.
+    This check includes 2 items:
+    1. the sharding spec covers all dimensions of the physical tensor
+    2. the sharding spec for each dimension is divisible by the number of devices.
+    #
+    """
+    # make sure all dims are covered in sharding spec
+    sharding_len = len(sharding_spec.sharding_sequence)
+    tensor_num_dim = tensor.dim()
+    num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0]
+    num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1]
+    assert sharding_len == tensor_num_dim, \
+        f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).'
+
+    # make sure the sharding is valid for each dim
+    for i in range(tensor_num_dim):
+        dim_size = tensor.shape[i]
+        dim_spec = sharding_spec.sharding_sequence[i]
+
+        if str(dim_spec).startswith('S'):
+            devices_str = str(dim_spec).lstrip('S')
+            num_devices = 1
+
+            if '0' in devices_str:
+                num_devices *= num_devices_in_col
+            if '1' in devices_str:
+                num_devices *= num_devices_in_row
+
+            assert dim_size >= num_devices and dim_size % num_devices == 0, \
+                f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'
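To make the rule concrete, here is a self-contained replica of the divisibility check with made-up shapes on a 2x2 device mesh; the num_shards and is_valid helpers below are illustrative assumptions, not part of the repo.

def num_shards(dim_spec, mesh_shape=(2, 2)):
    # 'R' is replicated (1 shard); 'S0', 'S1' and 'S01' shard over 2, 2 and 4 devices on a 2x2 mesh
    if not dim_spec.startswith('S'):
        return 1
    devices = 1
    if '0' in dim_spec[1:]:
        devices *= mesh_shape[0]
    if '1' in dim_spec[1:]:
        devices *= mesh_shape[1]
    return devices


def is_valid(shape, sharding_sequence, mesh_shape=(2, 2)):
    # mirrors the two assertions above: full coverage plus per-dimension divisibility
    if len(sharding_sequence) != len(shape):
        return False
    return all(size >= num_shards(spec, mesh_shape) and size % num_shards(spec, mesh_shape) == 0
               for size, spec in zip(shape, sharding_sequence))


print(is_valid((4, 16), ['S01', 'R']))    # True:  4 is divisible by the 4 devices of S01
print(is_valid((6, 16), ['S01', 'R']))    # False: 6 % 4 != 0, an illegal strategy
print(is_valid((4, 16), ['S0']))          # False: the spec does not cover every dimension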
@@ -1,37 +0,0 @@
-import torch
-
-from colossalai.tensor.sharding_spec import ShardingSpec
-
-
-def is_sharding_spec_valid(sharding_spec: ShardingSpec, tensor: torch.Tensor):
-    """
-    This function checks whether the ShardingSpec is valid for the physical tensor.
-    This check includes 2 items:
-    1. the sharding spec covers all dimensions of the physical tensor
-    2. the sharding spec for each dimension is divisible by the number of devices.
-    #
-    """
-    # make sure all dims are covered in sharding spec
-    sharding_len = len(sharding_spec.sharding_sequence)
-    tensor_num_dim = tensor.dim()
-    num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0]
-    num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1]
-    assert sharding_len == tensor_num_dim, \
-        f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).'
-
-    # make sure the sharding is valid for each dim
-    for i in range(tensor_num_dim):
-        dim_size = tensor.shape[i]
-        dim_spec = sharding_spec.sharding_sequence[i]
-
-        if str(dim_spec).startswith('S'):
-            devices_str = str(dim_spec).lstrip('S')
-            num_devices = 1
-
-            if '0' in devices_str:
-                num_devices *= num_devices_in_col
-            if '1' in devices_str:
-                num_devices *= num_devices_in_row
-
-            assert dim_size >= num_devices and dim_size % num_devices == 0, \
-                f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'
@@ -1,11 +1,10 @@
 import torch
 import torch.nn as nn
 
-from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import (ConvFunctionHandler, ConvModuleHandler)
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
+from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler, ConvModuleHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
-from colossalai.fx.tracer.meta_patch.patched_module import linear
 
 
 def test_conv_module_handler():
@@ -1,13 +1,15 @@
 import torch
 import torch.nn as nn
 
-from colossalai.auto_parallel.tensor_shard.node_handler.dot_handler import (LinearFunctionHandler, LinearModuleHandler)
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, ShardingStrategy,
-                                                                     StrategiesVector)
+from colossalai.auto_parallel.tensor_shard.node_handler.dot_handler import LinearFunctionHandler, LinearModuleHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    OperationData,
+    OperationDataType,
+    ShardingStrategy,
+    StrategiesVector,
+)
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
-from tests.test_auto_parallel.test_tensor_shard.test_node_handler.common import \
-    is_sharding_spec_valid
 
 
 def test_linear_module_handler():
@@ -92,12 +94,6 @@ def test_linear_module_handler():
         bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
         output_sharding_spec = strategy.get_sharding_spec_by_name('_0')
 
-        # make sure the sharding spec is valid
-        is_sharding_spec_valid(input_sharding_spec, torch.rand(2, 2, 4, 16))
-        is_sharding_spec_valid(weight_sharding_spec, model.get_parameter('0.weight'))
-        is_sharding_spec_valid(bias_sharding_spec, model.get_parameter('0.bias'))
-        is_sharding_spec_valid(output_sharding_spec, torch.rand([2, 2, 4, 32]))
-
         # make sure the sharding matches across different operation data
         assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
         assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
@@ -182,12 +178,6 @@ def test_linear_function_handler():
         bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
         output_sharding_spec = strategy.get_sharding_spec_by_name('linear')
 
-        # make sure the sharding spec is valid
-        is_sharding_spec_valid(input_sharding_spec, torch.rand(2, 2, 4, 16))
-        is_sharding_spec_valid(weight_sharding_spec, model.get_parameter('weight'))
-        is_sharding_spec_valid(bias_sharding_spec, model.get_parameter('bias'))
-        is_sharding_spec_valid(output_sharding_spec, torch.rand([2, 2, 4, 32]))
-
        # make sure the sharding matches across different operation data
         assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
         assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]