mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] handled illegal sharding strategy in shape consistency (#1744)
* [autoparallel] handled illegal sharding strategy in shape consistency
* polish code
parent 88a79814fb
commit 993b8875b6
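
The change is small but repeated: every place where ShapeConsistencyManager materialises a candidate ShardingSpec is now wrapped in a try/except, so a candidate dim_partition_dict that ShardingSpec rejects is skipped instead of aborting strategy construction. Condensed from the three ShapeConsistencyManager hunks below (the loop variables come from those hunks; the exact reasons a spec can be illegal, e.g. a tensor dimension that cannot be split over the assigned mesh axes, are an assumption not spelled out in this diff):

    try:
        new_sharding_spec = ShardingSpec(source_spec.device_mesh,
                                         source_spec.entire_shape,
                                         dim_partition_dict=new_dim_partition_dict)
        for phase, cost in cost_dict.items():
            cost_dict[phase] = cost + orig_cost_dict[phase]
        valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
    except ShardingSpecException:
        # illegal sharding strategy: drop this candidate and keep searching
        pass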
@@ -3,7 +3,7 @@ from typing import Dict, List
 import torch
 import torch.nn.functional as F
 
-from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy)
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
 from .node_handler import ModuleHandler, NodeHandler
 from .registry import operator_registry
 from .strategy import ConvStrategyGenerator, StrategyGenerator
@@ -68,7 +68,7 @@ class ConvModuleHandler(ModuleHandler):
                     dim_partition_dict[1] = second_dim_partition
 
                 # re-init the sharding spec
-                sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
+                sharding_spec.__init__(sharding_spec.device_mesh, op_data.data.shape, dim_partition_dict)
         return strategy
 
 
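
The ConvModuleHandler hunk above rebuilds the weight's sharding spec after swapping its dim-0/dim-1 partitions back, presumably because the strategy was generated against a transposed logical view of the weight; the fix re-initialises the spec against the tensor's physical shape (op_data.data.shape) instead of the entire_shape still recorded on the old spec. A toy sketch of that swap with a hypothetical Conv2d weight shape, using plain dicts rather than the real ShardingSpec:

    # hypothetical nn.Conv2d weight, physical shape (out_channels, in_channels, kH, kW)
    physical_shape = (64, 32, 3, 3)

    # partitions chosen while dims 0 and 1 were viewed transposed
    logical_partition = {0: [0], 1: [1]}

    # swap the entries back to the physical layout, mirroring the handler's post-processing
    physical_partition = {}
    if 1 in logical_partition:
        physical_partition[0] = logical_partition[1]
    if 0 in logical_partition:
        physical_partition[1] = logical_partition[0]

    # the fix: rebuild the spec against physical_shape (op_data.data.shape),
    # not the shape the old spec still carries
    print(physical_partition, physical_shape)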
@@ -46,6 +46,7 @@ class NodeHandler(ABC):
         # TODO: test this function when other handlers are ready
         resharding_costs = {}
         shape_consistency_manager = ShapeConsistencyManager()
+
         for node in self.predecessor_node:
             node_name = str(node)
 
@@ -54,7 +55,9 @@ class NodeHandler(ABC):
             assert hasattr(node, 'strategies_vector'), \
                 f'The predecessor node {node_name} has no strategy vector to compute the resharding cost.'
             prev_strategy_vector = node.strategies_vector
-            prev_sharding_specs = [strategy.get_sharding_spec_by_name(node_name) for strategy in prev_strategy_vector]
+            prev_sharding_specs = [
+                prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector
+            ]
 
             # get the current sharding spec generated by this node handler
             op_data = strategy.get_op_data_by_name(node_name)
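
The NodeHandler hunks above only touch how the predecessors' sharding specs are collected; the rest of the resharding-cost computation lies outside this diff. A hedged sketch of how those pieces plausibly continue (the accessor names appear in the hunk, the bookkeeping of resharding_costs is an assumption), using the ShapeConsistencyManager API documented further down:

    # assumed continuation of the method; not part of this diff
    current_sharding_spec = strategy.get_sharding_spec_by_name(node_name)

    resharding_costs[node] = []
    for prev_sharding_spec in prev_sharding_specs:
        # shape_consistency() returns (transform_path, comm_action_sequence, total_cost_dict)
        _, _, total_cost_dict = shape_consistency_manager.shape_consistency(
            prev_sharding_spec, current_sharding_spec)
        resharding_costs[node].append(total_cost_dict)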
@@ -1,15 +1,17 @@
-from ast import NodeTransformer
-import torch
-from typing import List
-from torch.fx import symbolic_trace
-from torch.fx.node import Node
-from colossalai.fx.passes.split_module import split_module
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
 import builtins
 import operator
+from ast import NodeTransformer
 from copy import deepcopy
+from typing import List
+
+import torch
+from torch.fx import symbolic_trace
+from torch.fx.node import Node
+
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.passes.split_module import split_module
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
 
 shape_consistency_manager = ShapeConsistencyManager()
 
@@ -1,16 +1,19 @@
-import torch
-from dataclasses import dataclass
-from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
-from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
-from enum import Enum
-from copy import deepcopy
-from typing import Dict, List, Optional, Tuple, Union
-from colossalai.context.singleton_meta import SingletonMeta
-import torch.distributed as dist
 import math
-from functools import reduce
 import operator
+from copy import deepcopy
+from dataclasses import dataclass
+from enum import Enum
+from functools import reduce
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.distributed as dist
 from torch.distributed import ReduceOp
+
+from colossalai.context.singleton_meta import SingletonMeta
+from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException, _DimSpec
+from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
+
 from .comm_spec import *
 
 __all__ = ['ShapeConsistencyManager', 'ShapeConsistencyOptions', 'set_shape_consistency_options']
@@ -62,10 +65,10 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
 
     def get_all_all_gather_spec(self, source_spec, orig_cost_dict):
         '''
         Get all valid sharding specs from source_spec with single all-gather operation, and
         accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
         For the all-gather operation, we just care about the S dimension.
 
         Argument:
             source_spec(ShardingSpec): the ShardingSpec of the source_spec.
             orig_cost(float): the original communication cost before this operation.
@@ -82,12 +85,12 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         shape_consistency_manager = ShapeConsistencyManager()
         rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, 0)
         print(rst_dict)
 
         Output:
         {DistSpec:
                 shard_sequence: R,S1,R
                 device_mesh_shape: (4, 4): 0, DistSpec:
                 shard_sequence: S0,R,R
                 device_mesh_shape: (4, 4): 0}
         '''
         valid_spec_dict = {}
@@ -120,20 +123,23 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             cost_dict = comm_spec.get_comm_cost()
 
             # generate new sharding spec
-            new_sharding_spec = ShardingSpec(source_spec.device_mesh,
-                                             source_spec.entire_shape,
-                                             dim_partition_dict=new_dim_partition_dict)
-            for phase, cost in cost_dict.items():
-                cost_dict[phase] = cost + orig_cost_dict[phase]
-            valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
+            try:
+                new_sharding_spec = ShardingSpec(source_spec.device_mesh,
+                                                 source_spec.entire_shape,
+                                                 dim_partition_dict=new_dim_partition_dict)
+                for phase, cost in cost_dict.items():
+                    cost_dict[phase] = cost + orig_cost_dict[phase]
+                valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
+            except ShardingSpecException:
+                pass
         return valid_spec_dict
 
     def get_all_all_to_all_spec(self, source_spec, orig_cost_dict):
         '''
         Get all valid sharding specs from source_spec with single all-to-all operation, and
         accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
         For the all-to-all operation, we just care about the pairs containing S dimension.
 
         Argument:
             source_spec(ShardingSpec): the ShardingSpec of the source_spec.
             orig_cost(float): the original communication cost before this operation.
@@ -150,14 +156,14 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         shape_consistency_manager = ShapeConsistencyManager()
         rst_dict = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec, 0)
         print(rst_dict)
 
         Output:
         {DistSpec:
                 shard_sequence: S01,R,R
                 device_mesh_shape: (4, 4): 0, DistSpec:
                 shard_sequence: R,S1,S0
                 device_mesh_shape: (4, 4): 0, DistSpec:
                 shard_sequence: S0,R,S1
                 device_mesh_shape: (4, 4): 0}
         '''
         valid_spec_dict = {}
@@ -223,20 +229,24 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                     new_dim_partition_dict.pop(b_index)
 
                 # generate new sharding spec
-                new_sharding_spec = ShardingSpec(source_spec.device_mesh,
-                                                 source_spec.entire_shape,
-                                                 dim_partition_dict=new_dim_partition_dict)
-                for phase, cost in cost_dict.items():
-                    cost_dict[phase] = cost + orig_cost_dict[phase]
-                valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
+                try:
+                    new_sharding_spec = ShardingSpec(source_spec.device_mesh,
+                                                     source_spec.entire_shape,
+                                                     dim_partition_dict=new_dim_partition_dict)
+                    for phase, cost in cost_dict.items():
+                        cost_dict[phase] = cost + orig_cost_dict[phase]
+                    valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
+                except ShardingSpecException:
+                    pass
+
         return valid_spec_dict
 
     def get_all_shard_spec(self, source_spec, orig_cost_dict):
         '''
         Get all valid sharding specs from source_spec with single shard operation, and
         accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
         For the sharding operation, we just care about legal sharding dimensions.
 
         Argument:
             source_spec(ShardingSpec): the ShardingSpec of the source_spec.
             orig_cost(float): the original communication cost before this operation.
@@ -253,14 +263,14 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         shape_consistency_manager = ShapeConsistencyManager()
         rst_dict = shape_consistency_manager.get_all_shard_spec(sharding_spec, 0)
         print(rst_dict)
 
         Output:
         {DistSpec:
                 shard_sequence: S01,R,R
                 device_mesh_shape: (4, 4): 0, DistSpec:
                 shard_sequence: S0,S1,R
                 device_mesh_shape: (4, 4): 0, DistSpec:
                 shard_sequence: S0,R,S1
                 device_mesh_shape: (4, 4): 0}
         '''
         valid_spec_dict = {}
@@ -275,6 +285,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             return valid_spec_dict
 
         tensor_dims = len(source_spec.entire_shape)
+
         for index in range(tensor_dims):
             if index not in source_spec.dim_partition_dict:
                 shard_list_list = shard_simulator((index, []), legal_sharding_dims)
@@ -300,23 +311,26 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                 cost_dict = comm_spec.get_comm_cost()
 
                 # generate new sharding spec
-                new_sharding_spec = ShardingSpec(source_spec.device_mesh,
-                                                 source_spec.entire_shape,
-                                                 dim_partition_dict=new_dim_partition_dict)
-                for phase, cost in cost_dict.items():
-                    cost_dict[phase] = cost + orig_cost_dict[phase]
-                valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
+                try:
+                    new_sharding_spec = ShardingSpec(source_spec.device_mesh,
+                                                     source_spec.entire_shape,
+                                                     dim_partition_dict=new_dim_partition_dict)
+                    for phase, cost in cost_dict.items():
+                        cost_dict[phase] = cost + orig_cost_dict[phase]
+                    valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
+                except ShardingSpecException:
+                    pass
         return valid_spec_dict
 
     def get_all_one_step_transform_spec(self, source_spec, orig_cost_dict):
         '''
         Get all valid sharding specs from source_spec with one step transform, and
         accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
         Note:
             all-gather will eliminate a sharding dimension, all-to-all will keep sharding dimension same as before,
             and shard will add a sharding dimension. Therefore, the result of above operations are mutual exclusive,
             we could safely put them together.
 
         Argument:
             source_spec(ShardingSpec): the ShardingSpec of the source_spec.
             orig_cost(float): the original communication cost before this operation.
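
Because the docstring above states that the all-gather, all-to-all and shard candidates are mutually exclusive, the one-step candidates can simply be merged into one dict. A sketch of a body consistent with that description (the real body is not part of this diff, so treat it as an assumption); note that with the try/except added above, illegal candidates never reach these dicts, so the merged set only contains specs the device mesh can realise:

    def get_all_one_step_transform_spec(self, source_spec, orig_cost_dict):
        # assumed body, consistent with the docstring above but not shown in this diff
        valid_spec_dict = {}
        valid_spec_dict.update(self.get_all_all_gather_spec(source_spec, orig_cost_dict))
        valid_spec_dict.update(self.get_all_all_to_all_spec(source_spec, orig_cost_dict))
        valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost_dict))
        return valid_spec_dict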
@@ -343,7 +357,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         Repeat above steps until the source spec transform to target spec.
 
         During finding the transform path, commucation cost will be accumulated, and it
         will be finally used in auto parallel solver.
 
         Additionally, to avoid repeating the path search in runtime, we cached all solved path
         in auto parallel strategy building time, which could handle most of cases in runtime.
@@ -361,30 +375,30 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         Example:
             dim_partition_source = {1: [0, 1]}
             dim_partition_target = {0: [0, 1]}
             # DistSpec:
             # shard_sequence: R,S01,R
             # device_mesh_shape: (4, 4)
             sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source)
             # DistSpec:
             # shard_sequence: S01,R,R
             # device_mesh_shape: (4, 4)
             sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target)
             transform_path, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(sharding_spec_source, sharding_spec_target)
             print(f'transform_path: {transform_path}')
             print(f'comm_action_sequence: {comm_action_sequence}')
             print(f'total_cost: {total_cost}')
 
         output:
             transform_path: [DistSpec:
                     shard_sequence: R,S01,R
                     device_mesh_shape: (4, 4), DistSpec:
                     shard_sequence: R,S0,R
                     device_mesh_shape: (4, 4), DistSpec:
                     shard_sequence: S0,R,R
                     device_mesh_shape: (4, 4), DistSpec:
                     shard_sequence: S01,R,R
                     device_mesh_shape: (4, 4)]
             comm_action_sequence: [CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1),
                     CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:0, logical_process_axis: 0),
                     CommSpec:(comm_pattern:shard, shard_dim:0, logical_process_axis:1)]
             total_cost: 12294.402000000002
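
The docstring example above, distilled into a standalone sketch; the mesh construction and entire_shape are assumptions filled in around the values shown (a (4, 4) mesh and the R,S01,R -> S01,R,R transform):

    import torch

    from colossalai.device.device_mesh import DeviceMesh
    from colossalai.tensor.shape_consistency import ShapeConsistencyManager
    from colossalai.tensor.sharding_spec import ShardingSpec

    # assumed setup: 16 devices arranged as a 4 x 4 logical mesh, and a 3-D tensor shape
    physical_mesh_id = torch.arange(0, 16)
    mesh_shape = (4, 4)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    entire_shape = torch.Size((64, 64, 64))

    # R,S01,R  ->  S01,R,R, as in the example above
    sharding_spec_source = ShardingSpec(device_mesh, entire_shape, {1: [0, 1]})
    sharding_spec_target = ShardingSpec(device_mesh, entire_shape, {0: [0, 1]})

    shape_consistency_manager = ShapeConsistencyManager()
    transform_path, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
        sharding_spec_source, sharding_spec_target)
    print(f'transform_path: {transform_path}')
    print(f'total_cost: {total_cost}')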
@@ -403,6 +417,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             return (transform_path, comm_action_sequence, total_cost_dict)
 
         temp_sharding_spec = source_spec
+
         transform_path.append(temp_sharding_spec)
         # To avoid dead loop, the loop will break after MAX_TRANSFORM_STEPS transforms
         while total_steps <= MAX_TRANSFORM_STEPS:
@@ -437,13 +452,13 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
 
     def apply(self, tensor_with_sharding_spec, target_spec):
         '''
         Apply target_spec to tensor with source sharding spec, the transform path is generated by the
         shape_consistency method.
 
         Argument:
             tensor_with_sharding_spec (torch.Tensor): a tensor with source sharding spec to be transformed to the target spec.
             target_spec (ShardingSpec): The tensor transform processes will be directed by the target_spec.
 
         Example:
             physical_mesh_id = torch.arange(0, 4)
             mesh_shape = (2, 2)
@@ -459,7 +474,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             # shard_sequence: S0,R
             # device_mesh_shape: (2, 2)
             sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source)
 
             # DistSpec:
             # shard_sequence: R,S0
             # device_mesh_shape: (2, 2)
@@ -481,13 +496,13 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             tensor_to_comm.sharding_spec = sharding_spec_source
             shape_consistency_manager.apply(tensor_to_comm, sharding_spec_target)
             print(tensor_to_comm)
 
         Output in rank0 and rank2:
             tensor([[0.],
                     [0.],
                     [2.],
                     [2.]])
 
         Output in rank1 and rank3:
             tensor([[1.],
                     [1.],
@@ -505,4 +520,4 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         for comm_spec in comm_action_sequence:
             comm_spec.covert_spec_to_action(tensor)
         tensor.sharding_spec = target_spec
         return tensor
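
Finally, the apply() docstring above as a standalone sketch. Everything not visible in the hunks (process-group setup, the local shard itself, the exact dim_partition dicts) is an assumption; it would have to run under a 4-process launch, e.g. torchrun --nproc_per_node=4:

    import torch
    import torch.distributed as dist

    from colossalai.device.device_mesh import DeviceMesh
    from colossalai.tensor.shape_consistency import ShapeConsistencyManager
    from colossalai.tensor.sharding_spec import ShardingSpec

    dist.init_process_group(backend='nccl')
    torch.cuda.set_device(dist.get_rank())

    # assumed setup, mirroring the docstring: 4 ranks on a 2 x 2 mesh
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

    entire_shape = torch.Size((4, 4))            # assumed global shape
    tensor_to_comm = torch.rand(2, 4).cuda()     # local shard of an S0,R layout (assumed values)

    sharding_spec_source = ShardingSpec(device_mesh, entire_shape, {0: [0]})    # S0,R
    sharding_spec_target = ShardingSpec(device_mesh, entire_shape, {1: [0]})    # R,S0

    tensor_to_comm.sharding_spec = sharding_spec_source
    shape_consistency_manager = ShapeConsistencyManager()
    shape_consistency_manager.apply(tensor_to_comm, sharding_spec_target)
    print(tensor_to_comm)    # each rank now holds its shard of the R,S0 layout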