[autoparallel] handled illegal sharding strategy in shape consistency (#1744)

* [autoparallel] handled illegal sharding strategy in shape consistency

* polish code
Frank Lee 2022-10-20 12:06:25 +08:00 committed by GitHub
parent 88a79814fb
commit 993b8875b6
4 changed files with 109 additions and 89 deletions
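
The change is small but repeated: wherever the strategy enumeration builds a candidate `ShardingSpec`, an illegal `dim_partition_dict` now raises `ShardingSpecException` and that candidate is skipped instead of crashing the whole search. Below is a minimal, self-contained sketch of the pattern; the stub `ShardingSpec`/`ShardingSpecException` classes and the `enumerate_valid_specs` helper are simplified stand-ins for illustration, not the library's real implementation.

```python
# Minimal sketch of the "skip illegal sharding strategies" pattern this commit applies.
# ShardingSpec and ShardingSpecException below are simplified stand-ins for the classes
# in colossalai.tensor.sharding_spec; only the control flow mirrors the real change.


class ShardingSpecException(Exception):
    """Raised when a dim_partition_dict cannot be applied to a tensor shape."""


class ShardingSpec:

    def __init__(self, device_mesh, entire_shape, dim_partition_dict=None):
        self.device_mesh = device_mesh
        self.entire_shape = tuple(entire_shape)
        self.dim_partition_dict = dim_partition_dict or {}
        # Stand-in legality check: reject partitions on dimensions the tensor does not have.
        for dim in self.dim_partition_dict:
            if dim >= len(self.entire_shape):
                raise ShardingSpecException(f"cannot shard dim {dim} of shape {self.entire_shape}")


def enumerate_valid_specs(device_mesh, entire_shape, candidate_partitions):
    """Keep only the candidates whose ShardingSpec can actually be constructed."""
    valid_specs = []
    for dim_partition_dict in candidate_partitions:
        try:
            spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict=dim_partition_dict)
        except ShardingSpecException:
            # Illegal strategy: drop this candidate instead of aborting the whole search.
            continue
        valid_specs.append(spec)
    return valid_specs


if __name__ == "__main__":
    candidates = [{0: [0]}, {3: [1]}, {1: [0, 1]}]    # the second entry is illegal for a 2-D tensor
    specs = enumerate_valid_specs(device_mesh=None, entire_shape=(64, 128), candidate_partitions=candidates)
    print(len(specs))    # 2 -- the illegal candidate was skipped, not raised
```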


@@ -3,7 +3,7 @@ from typing import Dict, List
import torch
import torch.nn.functional as F
-from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy)
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from .node_handler import ModuleHandler, NodeHandler
from .registry import operator_registry
from .strategy import ConvStrategyGenerator, StrategyGenerator
@@ -68,7 +68,7 @@ class ConvModuleHandler(ModuleHandler):
dim_partition_dict[1] = second_dim_partition
# re-init the sharding spec
-sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
+sharding_spec.__init__(sharding_spec.device_mesh, op_data.data.shape, dim_partition_dict)
return strategy


@@ -46,6 +46,7 @@ class NodeHandler(ABC):
# TODO: test this function when other handlers are ready
resharding_costs = {}
+shape_consistency_manager = ShapeConsistencyManager()
for node in self.predecessor_node:
node_name = str(node)
@@ -54,7 +55,9 @@
assert hasattr(node, 'strategies_vector'), \
f'The predecessor node {node_name} has no strategy vector to compute the resharding cost.'
prev_strategy_vector = node.strategies_vector
-prev_sharding_specs = [strategy.get_sharding_spec_by_name(node_name) for strategy in prev_strategy_vector]
+prev_sharding_specs = [
+    prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector
+]
# get the current sharding spec generated by this node handler
op_data = strategy.get_op_data_by_name(node_name)
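
The hunk above wires a `ShapeConsistencyManager` into the resharding-cost loop: for every predecessor strategy, the manager prices the conversion from the predecessor's sharding spec to the spec this node requires. A rough sketch of that bookkeeping is below; the `consistency_fn` callable mimics `ShapeConsistencyManager.shape_consistency` (which, judging from the final hunks of this commit, returns `(transform_path, comm_action_sequence, total_cost_dict)`), while the `'total'` cost key and the toy cost function are assumptions for illustration.

```python
from typing import Any, Callable, Dict, List, Tuple

# Sketch of assembling a resharding-cost table: for each predecessor node and each of its
# strategies, ask a shape-consistency routine how expensive it is to convert that strategy's
# sharding spec into the spec the current strategy requires.


def build_resharding_costs(
    predecessors: Dict[str, List[Any]],    # node name -> sharding specs of its strategies
    target_specs: Dict[str, Any],          # node name -> spec required by the current strategy
    consistency_fn: Callable[[Any, Any], Tuple[list, list, Dict[str, float]]],
) -> Dict[str, List[float]]:
    resharding_costs: Dict[str, List[float]] = {}
    for node_name, prev_specs in predecessors.items():
        target = target_specs[node_name]
        costs = []
        for prev_spec in prev_specs:
            _, _, total_cost_dict = consistency_fn(prev_spec, target)
            costs.append(total_cost_dict["total"])    # the 'total' key is an assumption
        resharding_costs[node_name] = costs
    return resharding_costs


# Toy usage: specs are sets of partitioned dims, cost is the number of mismatched dims.
toy_fn = lambda src, dst: ([], [], {"total": float(len(set(src) ^ set(dst)))})
print(build_resharding_costs({"conv1": [{0}, {0, 1}]}, {"conv1": {1}}, toy_fn))    # {'conv1': [2.0, 1.0]}
```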


@@ -1,15 +1,17 @@
-from ast import NodeTransformer
-import torch
-from typing import List
-from torch.fx import symbolic_trace
-from torch.fx.node import Node
-from colossalai.fx.passes.split_module import split_module
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
import builtins
import operator
+from ast import NodeTransformer
+from copy import deepcopy
+from typing import List
+import torch
+from torch.fx import symbolic_trace
+from torch.fx.node import Node
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.passes.split_module import split_module
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
shape_consistency_manager = ShapeConsistencyManager()


@@ -1,16 +1,19 @@
-import torch
-from dataclasses import dataclass
-from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
-from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
-from enum import Enum
-from copy import deepcopy
-from typing import Dict, List, Optional, Tuple, Union
-from colossalai.context.singleton_meta import SingletonMeta
-import torch.distributed as dist
import math
-from functools import reduce
import operator
+from copy import deepcopy
+from dataclasses import dataclass
+from enum import Enum
+from functools import reduce
+from typing import Dict, List, Optional, Tuple, Union
+import torch
+import torch.distributed as dist
+from torch.distributed import ReduceOp
+from colossalai.context.singleton_meta import SingletonMeta
+from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException, _DimSpec
+from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
from .comm_spec import *
__all__ = ['ShapeConsistencyManager', 'ShapeConsistencyOptions', 'set_shape_consistency_options']
@@ -120,12 +123,15 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
cost_dict = comm_spec.get_comm_cost()
# generate new sharding spec
+try:
    new_sharding_spec = ShardingSpec(source_spec.device_mesh,
                                     source_spec.entire_shape,
                                     dim_partition_dict=new_dim_partition_dict)
    for phase, cost in cost_dict.items():
        cost_dict[phase] = cost + orig_cost_dict[phase]
    valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
+except ShardingSpecException:
+    pass
return valid_spec_dict
def get_all_all_to_all_spec(self, source_spec, orig_cost_dict):
@@ -223,12 +229,16 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
new_dim_partition_dict.pop(b_index)
# generate new sharding spec
+try:
    new_sharding_spec = ShardingSpec(source_spec.device_mesh,
                                     source_spec.entire_shape,
                                     dim_partition_dict=new_dim_partition_dict)
    for phase, cost in cost_dict.items():
        cost_dict[phase] = cost + orig_cost_dict[phase]
    valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
+except ShardingSpecException:
+    pass
return valid_spec_dict
def get_all_shard_spec(self, source_spec, orig_cost_dict):
@@ -275,6 +285,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
return valid_spec_dict
tensor_dims = len(source_spec.entire_shape)
for index in range(tensor_dims):
if index not in source_spec.dim_partition_dict:
shard_list_list = shard_simulator((index, []), legal_sharding_dims)
@@ -300,12 +311,15 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
cost_dict = comm_spec.get_comm_cost()
# generate new sharding spec
+try:
    new_sharding_spec = ShardingSpec(source_spec.device_mesh,
                                     source_spec.entire_shape,
                                     dim_partition_dict=new_dim_partition_dict)
    for phase, cost in cost_dict.items():
        cost_dict[phase] = cost + orig_cost_dict[phase]
    valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
+except ShardingSpecException:
+    pass
return valid_spec_dict
def get_all_one_step_transform_spec(self, source_spec, orig_cost_dict):
@@ -403,6 +417,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
return (transform_path, comm_action_sequence, total_cost_dict)
temp_sharding_spec = source_spec
transform_path.append(temp_sharding_spec)
# To avoid dead loop, the loop will break after MAX_TRANSFORM_STEPS transforms
while total_steps <= MAX_TRANSFORM_STEPS:
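
The last hunk sits in the transform loop that repeatedly applies one-step transforms until the source spec matches the target, giving up after `MAX_TRANSFORM_STEPS` to avoid an infinite loop. A schematic version of such a loop is sketched below; the stand-in types, the `one_step_transforms` callable, and the pick-the-cheapest heuristic are assumptions for illustration rather than the library's actual algorithm.

```python
from typing import Any, Callable, Dict, List, Tuple

MAX_TRANSFORM_STEPS = 20    # illustrative bound; the real constant lives in the library


def greedy_shape_consistency(
    source_spec: Any,
    target_spec: Any,
    one_step_transforms: Callable[[Any], Dict[Any, float]],    # spec -> {next_spec: step_cost}
) -> Tuple[List[Any], float]:
    """Schematic greedy search for a transform path between two sharding specs."""
    transform_path = [source_spec]
    total_cost = 0.0
    temp_spec = source_spec
    total_steps = 0
    # To avoid an infinite loop, give up after MAX_TRANSFORM_STEPS transforms.
    while total_steps <= MAX_TRANSFORM_STEPS:
        if temp_spec == target_spec:
            return transform_path, total_cost
        candidates = one_step_transforms(temp_spec)    # illegal specs are already filtered out
        if not candidates:
            break
        # Take the cheapest next spec; the real heuristic ranks candidates differently.
        temp_spec, step_cost = min(candidates.items(), key=lambda kv: kv[1])
        transform_path.append(temp_spec)
        total_cost += step_cost
        total_steps += 1
    raise RuntimeError("no transform path found within the step budget")


# Toy usage on integer "specs": each step moves one unit towards the target.
path, cost = greedy_shape_consistency(0, 3, lambda s: {s + 1: 1.0})
print(path, cost)    # [0, 1, 2, 3] 3.0
```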