mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] handled illegal sharding strategy in shape consistency (#1744)
* [autoparallel] handled illegal sharding strategy in shape consistency
* polish code

pull/1745/head
parent 88a79814fb
commit 993b8875b6
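The substance of the change, visible in the shape-consistency hunks below, is that every place the shape-consistency manager constructs a candidate ShardingSpec is now wrapped in try/except ShardingSpecException, so an illegal sharding strategy is skipped instead of aborting the whole enumeration. A minimal sketch of that pattern; the mesh, tensor shape, and candidate list are illustrative and not taken from the patch:

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException

# illustrative 2x2 logical device mesh over 4 devices
device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2))
entire_shape = torch.Size((4, 8))

# dim_partition_dict maps a tensor dim to the mesh dims it is sharded over
candidate_partitions = [
    {0: [0]},           # shard dim 0 over mesh dim 0
    {0: [0, 1]},        # shard dim 0 over both mesh dims
    {0: [0], 1: [1]},   # shard dim 0 and dim 1
]

valid_specs = []
for dim_partition_dict in candidate_partitions:
    try:
        spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict=dim_partition_dict)
    except ShardingSpecException:
        # illegal sharding strategy: skip it rather than failing the enumeration
        continue
    valid_specs.append(spec)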
@@ -3,7 +3,7 @@ from typing import Dict, List
 import torch
 import torch.nn.functional as F
 
-from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy)
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
 from .node_handler import ModuleHandler, NodeHandler
 from .registry import operator_registry
 from .strategy import ConvStrategyGenerator, StrategyGenerator
@@ -68,7 +68,7 @@ class ConvModuleHandler(ModuleHandler):
                 dim_partition_dict[1] = second_dim_partition
 
                 # re-init the sharding spec
-                sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
+                sharding_spec.__init__(sharding_spec.device_mesh, op_data.data.shape, dim_partition_dict)
         return strategy
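The second conv-handler hunk above is the other half of the fix: when the weight's sharding spec is rebuilt, it is now re-initialized against op_data.data.shape, the operand's current logical shape, rather than the stale entire_shape recorded when the spec was first created. A small sketch of re-initializing a spec in place under a new shape and partition; the mesh, shapes, and partitions are illustrative, not from the patch:

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec

device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2))

# suppose the strategy was generated against a transposed logical view of the weight
logical_shape = torch.Size((4, 16, 3, 3))
spec = ShardingSpec(device_mesh, logical_shape, dim_partition_dict={0: [0]})

# the fix: rebuild the spec against the operand's actual shape, swapping the
# partitioned dims accordingly, instead of reusing spec.entire_shape
actual_shape = torch.Size((16, 4, 3, 3))   # plays the role of op_data.data.shape
spec.__init__(spec.device_mesh, actual_shape, {1: [0]})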
@@ -46,6 +46,7 @@ class NodeHandler(ABC):
         # TODO: test this function when other handlers are ready
         resharding_costs = {}
+        shape_consistency_manager = ShapeConsistencyManager()
 
         for node in self.predecessor_node:
             node_name = str(node)
@@ -54,7 +55,9 @@ class NodeHandler(ABC):
             assert hasattr(node, 'strategies_vector'), \
                 f'The predecessor node {node_name} has no strategy vector to compute the resharding cost.'
             prev_strategy_vector = node.strategies_vector
-            prev_sharding_specs = [strategy.get_sharding_spec_by_name(node_name) for strategy in prev_strategy_vector]
+            prev_sharding_specs = [
+                prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector
+            ]
 
             # get the current sharding spec generated by this node handler
             op_data = strategy.get_op_data_by_name(node_name)
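The node-handler hunks above wire the shape-consistency machinery into resharding-cost computation: the handler now instantiates a ShapeConsistencyManager and collects, per predecessor, the sharding spec each of the predecessor's strategies produces for that operand. A rough sketch of how those pieces could combine, assuming the manager's shape_consistency(source_spec, target_spec) entry point returns (transform_path, comm_action_sequence, total_cost_dict) as in the final hunk of this diff, and that the cost dict carries a 'total' entry:

from colossalai.tensor.shape_consistency import ShapeConsistencyManager


def compute_resharding_costs(predecessor_nodes, strategy):
    """Sketch: estimate the cost of resharding each predecessor's output into the
    sharding spec the current strategy expects for that operand."""
    resharding_costs = {}
    shape_consistency_manager = ShapeConsistencyManager()

    for node in predecessor_nodes:
        node_name = str(node)
        prev_sharding_specs = [
            prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in node.strategies_vector
        ]
        current_sharding_spec = strategy.get_sharding_spec_by_name(node_name)

        costs = []
        for prev_spec in prev_sharding_specs:
            # shape_consistency(...) is assumed to return
            # (transform_path, comm_action_sequence, total_cost_dict)
            _, _, total_cost_dict = shape_consistency_manager.shape_consistency(prev_spec, current_sharding_spec)
            costs.append(total_cost_dict['total'])  # the 'total' key is an assumption
        resharding_costs[node] = costs
    return resharding_costs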
@@ -1,15 +1,17 @@
-from ast import NodeTransformer
-import torch
-from typing import List
-from torch.fx import symbolic_trace
-from torch.fx.node import Node
-from colossalai.fx.passes.split_module import split_module
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
+import builtins
+import operator
+from ast import NodeTransformer
+from copy import deepcopy
+from typing import List
+
+import torch
+from torch.fx import symbolic_trace
+from torch.fx.node import Node
+
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.passes.split_module import split_module
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
 
 shape_consistency_manager = ShapeConsistencyManager()
@@ -1,16 +1,19 @@
-import torch
-from dataclasses import dataclass
-from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
-from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
-from enum import Enum
-from copy import deepcopy
-from typing import Dict, List, Optional, Tuple, Union
-from colossalai.context.singleton_meta import SingletonMeta
-import torch.distributed as dist
-import math
-from functools import reduce
-import operator
+from copy import deepcopy
+from dataclasses import dataclass
+from enum import Enum
+from functools import reduce
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ReduceOp
+
+from colossalai.context.singleton_meta import SingletonMeta
+from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException, _DimSpec
+from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
+
+from .comm_spec import *
 
 __all__ = ['ShapeConsistencyManager', 'ShapeConsistencyOptions', 'set_shape_consistency_options']
@@ -120,12 +123,15 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             cost_dict = comm_spec.get_comm_cost()
 
             # generate new sharding spec
+            try:
                 new_sharding_spec = ShardingSpec(source_spec.device_mesh,
                                                  source_spec.entire_shape,
                                                  dim_partition_dict=new_dim_partition_dict)
                 for phase, cost in cost_dict.items():
                     cost_dict[phase] = cost + orig_cost_dict[phase]
                 valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
+            except ShardingSpecException:
+                pass
         return valid_spec_dict
 
     def get_all_all_to_all_spec(self, source_spec, orig_cost_dict):
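For context on when the new except ShardingSpecException branch fires: a candidate dim_partition_dict can request a partition the tensor cannot support, for example sharding a dimension over more devices than it has elements. A tiny illustration, under the assumption that ShardingSpec's sanity checks reject such a partition with ShardingSpecException; the mesh and shape are made up:

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException

device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2))

# dim 1 has size 3, but {1: [0, 1]} asks for 4 shards along it
try:
    ShardingSpec(device_mesh, torch.Size((8, 3)), dim_partition_dict={1: [0, 1]})
except ShardingSpecException:
    print("illegal sharding strategy skipped")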
@@ -223,12 +229,16 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                 new_dim_partition_dict.pop(b_index)
 
             # generate new sharding spec
+            try:
                 new_sharding_spec = ShardingSpec(source_spec.device_mesh,
                                                  source_spec.entire_shape,
                                                  dim_partition_dict=new_dim_partition_dict)
                 for phase, cost in cost_dict.items():
                     cost_dict[phase] = cost + orig_cost_dict[phase]
                 valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
+            except ShardingSpecException:
+                pass
+
         return valid_spec_dict
 
     def get_all_shard_spec(self, source_spec, orig_cost_dict):
@@ -275,6 +285,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             return valid_spec_dict
 
         tensor_dims = len(source_spec.entire_shape)
 
         for index in range(tensor_dims):
             if index not in source_spec.dim_partition_dict:
                 shard_list_list = shard_simulator((index, []), legal_sharding_dims)
@@ -300,12 +311,15 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                 cost_dict = comm_spec.get_comm_cost()
 
                 # generate new sharding spec
+                try:
                     new_sharding_spec = ShardingSpec(source_spec.device_mesh,
                                                      source_spec.entire_shape,
                                                      dim_partition_dict=new_dim_partition_dict)
                     for phase, cost in cost_dict.items():
                         cost_dict[phase] = cost + orig_cost_dict[phase]
                     valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
+                except ShardingSpecException:
+                    pass
         return valid_spec_dict
 
     def get_all_one_step_transform_spec(self, source_spec, orig_cost_dict):
@@ -403,6 +417,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             return (transform_path, comm_action_sequence, total_cost_dict)
 
         temp_sharding_spec = source_spec
 
         transform_path.append(temp_sharding_spec)
         # To avoid dead loop, the loop will break after MAX_TRANSFORM_STEPS transforms
         while total_steps <= MAX_TRANSFORM_STEPS:
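The final hunk above sits in the routine that assembles the actual transform path. As its comment says, the search is a greedy loop bounded by MAX_TRANSFORM_STEPS: starting from the source spec, it keeps applying a one-step transform that moves closest to the target until the target spec is reached or the step budget is exhausted. A schematic sketch of that control flow; one_step_transforms and spec_difference are stand-ins for the manager's real enumeration and heuristic, not its API:

MAX_TRANSFORM_STEPS = 20  # illustrative bound; the real constant lives in the manager


def greedy_transform_path(source_spec, target_spec, one_step_transforms, spec_difference):
    """Sketch of the bounded greedy search the shape-consistency manager performs.

    one_step_transforms(spec) -> {new_spec: (comm_spec, cost_dict)} for every legal
    single communication step (all-gather / all-to-all / shard), with illegal specs
    already filtered out by the try/except shown above.
    spec_difference(a, b) -> heuristic distance between two sharding specs.
    """
    transform_path = [source_spec]
    comm_action_sequence = []
    temp_spec = source_spec
    total_steps = 0

    while total_steps <= MAX_TRANSFORM_STEPS:
        if spec_difference(temp_spec, target_spec) == 0:
            return transform_path, comm_action_sequence

        candidates = one_step_transforms(temp_spec)
        if not candidates:
            break
        # greedily pick the candidate closest to the target spec
        temp_spec, (comm_spec, _) = min(
            candidates.items(), key=lambda item: spec_difference(item[0], target_spec))
        transform_path.append(temp_spec)
        comm_action_sequence.append(comm_spec)
        total_steps += 1

    raise RuntimeError("no legal transform path found within MAX_TRANSFORM_STEPS")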