[autoparallel] implemented all matmul strategy generator (#1650)

pull/1655/head
Frank Lee 2022-09-27 12:06:25 +08:00 committed by GitHub
parent 03978aad45
commit 30e50c8b4a
8 changed files with 440 additions and 76 deletions

View File

@@ -50,8 +50,16 @@ class LinearModuleHandler(ModuleHandler):
         if op_data.name == "weight":
             assert op_data.logical_shape != op_data.data.shape
             dim_partition_dict = sharding_spec.dim_partition_dict
             # switch first and last dim of the linear module weight
-            dim_partition_dict[0], dim_partition_dict[-1] = dim_partition_dict[-1], dim_partition_dict[0]
+            first_dim_partition = dim_partition_dict.pop(-1, None)
+            last_dim_partition = dim_partition_dict.pop(0, None)
+            if first_dim_partition:
+                dim_partition_dict[0] = first_dim_partition
+            if last_dim_partition:
+                dim_partition_dict[-1] = last_dim_partition
             # re-init the sharding spec
             sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
@@ -111,8 +119,16 @@ class LinearFunctionHandler(NodeHandler):
         if op_data.name == str(self.node.args[1]):
             assert op_data.logical_shape != op_data.data.shape
             dim_partition_dict = sharding_spec.dim_partition_dict
             # switch first and last dim of the linear module weight
-            dim_partition_dict[0], dim_partition_dict[-1] = dim_partition_dict[-1], dim_partition_dict[0]
+            first_dim_partition = dim_partition_dict.pop(-1, None)
+            last_dim_partition = dim_partition_dict.pop(0, None)
+            if first_dim_partition:
+                dim_partition_dict[0] = first_dim_partition
+            if last_dim_partition:
+                dim_partition_dict[-1] = last_dim_partition
             # re-init the sharding spec
             sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
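
Aside (not part of the diff): a minimal sketch of why the pop-based swap above replaces the old tuple assignment. When only one of the two dims is sharded, the old one-liner raises KeyError, while pop with a None default transposes whatever is present.

def swap_first_last(dim_partition_dict):
    # mirrors the new handler logic: move the sharding of dim -1 to dim 0 and vice versa
    first_dim_partition = dim_partition_dict.pop(-1, None)
    last_dim_partition = dim_partition_dict.pop(0, None)
    if first_dim_partition:
        dim_partition_dict[0] = first_dim_partition
    if last_dim_partition:
        dim_partition_dict[-1] = last_dim_partition
    return dim_partition_dict

print(swap_first_last({0: [0]}))   # {-1: [0]} -- only dim 0 was sharded, no KeyError

partial = {0: [0]}
# the replaced one-liner would fail here with KeyError(-1):
# partial[0], partial[-1] = partial[-1], partial[0]
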

View File

@@ -33,12 +33,12 @@ class NodeHandler(ABC):
         Register different sharding strategies for the current node.
         """
         strategy_generators = self.get_strategy_generator()
-        operand_mapping = self.get_operation_data_mapping()
         for generator in strategy_generators:
-            strategies = generator.generate(operand_mapping)
+            strategies = generator.generate()
             self.strategies_vector.extend(strategies)
-        self.strategies_vector = map(self.post_process, self.strategies_vector)
+        strategies_vector = map(self.post_process, self.strategies_vector)
+        self.strategies_vector = list(strategies_vector)
         return self.strategies_vector

     def post_process(self, strategy: ShardingStrategy_V2):
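
Aside (not part of the diff): map() returns a lazy, single-use iterator, which is why register_strategy now materializes the post-processed strategies back into a list before returning them.

strategies = ['S0S1 = S0R x RS1', 'RR = RS0 x S0R']

lazy = map(str.lower, strategies)
list(lazy)                  # a first pass consumes the iterator
print(list(lazy))           # [] -- a second pass silently sees nothing

materialized = list(map(str.lower, strategies))
print(materialized[0])      # indexing and repeated iteration now behave like a list
print(len(materialized))    # 2
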

View File

@@ -75,6 +75,12 @@ class OperationData:
         if self.logical_shape is None:
             self.logical_shape = self.data.shape

+    def __repr__(self) -> str:
+        return f'OperationData(name={self.name}, type={self.type})'
+
+    def __hash__(self) -> int:
+        return hash(f'{self.name}-{self.type}')
+

 @dataclass
 class TrainCycleItem:
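
Aside (not part of the diff): a standalone mimic of the pattern above (not the real colossalai classes) showing why the explicit __hash__ matters. OperationData instances are used as dictionary keys, e.g. strategy.sharding_specs[self.op_data['input']], and hashing on name and type keeps them usable as keys even though the wrapped meta tensor is not hashable by value.

from dataclasses import dataclass
from typing import Any

@dataclass
class OpData:                      # hypothetical stand-in for OperationData
    name: str
    type: str
    data: Any = None

    def __repr__(self) -> str:
        return f'OpData(name={self.name}, type={self.type})'

    def __hash__(self) -> int:
        return hash(f'{self.name}-{self.type}')

weight = OpData(name='weight', type='PARAM', data=[[0.0]])   # data itself is unhashable
sharding_specs = {weight: 'S0R'}                             # works thanks to the explicit __hash__
print(weight)                                                # OpData(name=weight, type=PARAM)
print(sharding_specs[weight])                                # S0R
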

View File

@@ -1,7 +1,4 @@
-from cmath import log
-from distutils.log import Log
 import operator
-import torch
 from functools import reduce
 from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
 from colossalai.tensor.shape_consistency import CollectiveCommPattern
@@ -9,17 +6,148 @@ from .strategy_generator import StrategyGenerator_V2
 from typing import List


-class DotProductStrategyGenerator(StrategyGenerator_V2):
-    """TODO: to be implemented"""
-    pass
+class MatMulStrategyGenerator(StrategyGenerator_V2):
+    """
+    MatMulStrategyGenerator is a generic class to cover all matrix multiplication cases.
+    The operation data is defined as `output = input x other + bias`.
+    """
+
+    @property
+    def has_bias(self):
+        return 'bias' in self.op_data
+
+    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+        size_mapping = {
+            'input': self._compute_size_in_bytes(strategy, "input"),
+            'other': self._compute_size_in_bytes(strategy, "other"),
+            'output': self._compute_size_in_bytes(strategy, "output")
+        }
+
+        if self.has_bias:
+            bias_size = self._compute_size_in_bytes(strategy, "bias")
+            size_mapping['bias'] = bias_size
+
+        # compute fwd cost incurred
+        # fwd_cost = input + other + bias + output
+        fwd_activation_cost = sum([v for k, v in size_mapping.items() if not self.is_param(k)])
+        fwd_parameter_cost = sum([v for k, v in size_mapping.items() if self.is_param(k)])
+        fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
+
+        # compute bwd cost incurred
+        # bwd_cost = input_grad + bias_grad
+        bwd_activation_cost = sum([v for k, v in size_mapping.items() if k in ['input', 'other', 'bias']])
+        bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=0)
+
+        # compute total cost
+        total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
+                                    parameter=fwd_parameter_cost + 0)
+        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
+        strategy.memory_cost = memory_cost
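
Aside (not part of the diff): a rough worked example of the bookkeeping in update_memory_cost, assuming _compute_size_in_bytes returns the per-device sharded tensor size in bytes (hypothetical fp32 sizes for a 4x16 input, 16x32 weight, 32 bias, and 4x32 output, fully replicated).

size_mapping = {
    'input': 4 * 16 * 4,     # 256 B, activation (ARG)
    'other': 16 * 32 * 4,    # 2048 B, parameter
    'bias': 32 * 4,          # 128 B, parameter
    'output': 4 * 32 * 4,    # 512 B, activation (OUTPUT)
}
is_param = {'input': False, 'other': True, 'bias': True, 'output': False}

fwd_activation_cost = sum(v for k, v in size_mapping.items() if not is_param[k])                   # 768
fwd_parameter_cost = sum(v for k, v in size_mapping.items() if is_param[k])                        # 2176
bwd_activation_cost = sum(v for k, v in size_mapping.items() if k in ['input', 'other', 'bias'])   # 2432
print(fwd_activation_cost, fwd_parameter_cost, bwd_activation_cost)
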
-class MatVecStrategyGenerator(StrategyGenerator_V2):
-    """TODO: to be implemented"""
-    pass
+class DotProductStrategyGenerator(MatMulStrategyGenerator):
+
+    def validate(self) -> bool:
+        input_op_data = self.op_data['input']
+        other_op_data = self.op_data['other']
+        assert input_op_data.data.dim() == 1 and other_op_data.data.dim() == 1
+
+    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+        sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
+        fwd_compute_cost = sharded_input_shape[0]
+        bwd_compute_cost = sharded_input_shape * 2
+        compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
+                                      bwd=bwd_compute_cost,
+                                      total=fwd_compute_cost + bwd_compute_cost)
+        return compute_cost
+
+    def no_split(self):
+        name = f'R = R dot R'
+        dim_partition_dict = {"input": {}, "other": {}, "output": {}, 'bias': {}}
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
+        communication_action_mapping = {}
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping=communication_action_mapping)
+
+    def split_one_dim(self, mesh_dim):
+        name = f'R = S{mesh_dim} dot S{mesh_dim}'
+
+        # get sharding spec
+        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "output": {}, "bias": {0: [mesh_dim]}}
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
+
+        # get communication action
+        output_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['output'],
+            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
+            logical_process_axis=mesh_dim)
+        communication_action_mapping = {"output": output_comm_spec}
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping=communication_action_mapping)
+
+    def generate(self) -> List[ShardingStrategy_V2]:
+        strategy_list = []
+
+        # do not split dimensions for dot product
+        # R = R dot R
+        strategy_list.append(self.no_split())
+
+        # split two tensors in the same dimensions
+        # S = S dot S
+        strategy_list.append(self.split_one_dim(0))
+        strategy_list.append(self.split_one_dim(1))
+
+        return strategy_list
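
Aside (not part of the diff): a plain-torch check, with no device mesh, of why the 'R = S0 dot S0' strategy above attaches an ALLREDUCE_FWD_IDENTITY_BWD action to the output: each shard computes only a partial dot product, and the partials must be summed across the mesh axis.

import torch

x = torch.arange(8, dtype=torch.float32)
y = torch.ones(8)
full = torch.dot(x, y)

# shard dim 0 of both operands across a mesh axis of size 2
partials = [torch.dot(xs, ys) for xs, ys in zip(x.chunk(2), y.chunk(2))]
allreduced = partials[0] + partials[1]    # stands in for the forward all-reduce

assert torch.allclose(full, allreduced)
print(full.item(), allreduced.item())     # 28.0 28.0
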
-class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
+class MatVecStrategyGenerator(MatMulStrategyGenerator):
+
+    def validate(self) -> bool:
+        input_op_data = self.op_data['input']
+        other_op_data = self.op_data['other']
+        assert input_op_data.data.dim() > 1 and other_op_data.data.dim() == 1
+
+    def no_split(self):
+        name = "R = R x R"
+        dim_partition_dict = {"input": {}, "other": {}, "output": {}, "bias": {}}
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping={})
+
+    def split_input_batch(self, mesh_dim):
+        name = f'S{mesh_dim}R = S{mesh_dim}R x R'
+
+        # get sharding spec
+        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}}
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
+
+        # get communication action
+        other_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['other'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=mesh_dim)
+        communication_action_mapping = {'other': other_comm_spec}
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping=communication_action_mapping)
+
+    def generate(self) -> List[ShardingStrategy_V2]:
+        strategy_list = []
+
+        # no split
+        strategy_list.append(self.no_split())
+
+        # split the batch dim for the first tensor only
+        strategy_list.append(self.split_input_batch(0))
+        strategy_list.append(self.split_input_batch(1))
+
+        return strategy_list
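
Aside (not part of the diff): a plain-torch sketch of the IDENTITY_FWD_ALLREDUCE_BWD action placed on 'other' in split_input_batch: when the batch dim is sharded and the vector operand stays replicated, each shard only accumulates part of the weight gradient, so the gradients need an all-reduce in the backward pass.

import torch

x = torch.randn(4, 8)                        # input, batch dim to be sharded
w = torch.randn(8, requires_grad=True)       # replicated mat-vec operand

(x @ w).sum().backward()
full_grad = w.grad.clone()                   # gradient without any sharding

shard_grads = []
for xs in x.chunk(2):                        # batch split across a mesh axis of size 2
    w_replica = w.detach().clone().requires_grad_(True)
    (xs @ w_replica).sum().backward()
    shard_grads.append(w_replica.grad)

allreduced = shard_grads[0] + shard_grads[1]  # backward all-reduce on the replicated operand
assert torch.allclose(full_grad, allreduced)
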
+class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):

     def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
         # C = AB
@@ -39,23 +167,6 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
                                       total=fwd_compute_cost + bwd_compute_cost)
         strategy.compute_cost = compute_cost

-    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
-        input_size = self._compute_size_in_bytes(strategy, "input")
-        other_size = self._compute_size_in_bytes(strategy, "input")
-
-        if "bias" in self.op_data:
-            bias_size = self._compute_size_in_bytes(strategy, "bias")
-        else:
-            bias_size = 0
-
-        output_size = self._compute_size_in_bytes(strategy, "output")
-
-        fwd_mem_cost = MemoryCost(activation=output_size, parameter=other_size + bias_size)
-        bwd_mem_cost = MemoryCost(activation=input_size + other_size + bias_size, parameter=other_size)
-        total_mem_cost = MemoryCost(activation=input_size + 2 * output_size + bias_size,
-                                    parameter=other_size + bias_size)
-        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
-        strategy.memory_cost = memory_cost
-
     def generate(self) -> List[ShardingStrategy_V2]:
         strategies = []
@@ -104,7 +215,7 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
                 0: [mesh_dim_0]
             },
             "other": {
-                self.dim_q: [mesh_dim_1]
+                -1: [mesh_dim_1]
             },
             "bias": {
                 -1: [mesh_dim_1]
@@ -143,7 +254,7 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
                 -1: [mesh_dim_1]
             },
             "other": {
-                self.dim_p: [mesh_dim_1]
+                0: [mesh_dim_1]
             },
             "bias": {},
             "output": {
@@ -159,7 +270,7 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
             logical_process_axis=mesh_dim_0)
         output_comm_spec = self.get_communication_spec(
             sharding_spec=sharding_spec_mapping["output"],
-            communication_pattern=CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD,
+            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=mesh_dim_1)
         communication_action_mapping = {"input": input_comm_spec, 'output': output_comm_spec}
@@ -177,8 +288,8 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
                 -1: [mesh_dim_0]
             },
             "other": {
-                self.dim_p: [mesh_dim_0],
-                self.dim_q: [mesh_dim_1]
+                0: [mesh_dim_0],
+                -1: [mesh_dim_1]
             },
             "bias": {
                 -1: [mesh_dim_1]
@@ -192,7 +303,7 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
         # get communication actions
         output_comm_spec = self.get_communication_spec(
             sharding_spec=sharding_spec_mapping['output'],
-            communication_pattern=CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD,
+            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=mesh_dim_0)
         input_comm_spec = self.get_communication_spec(
             sharding_spec=sharding_spec_mapping['input'],
@@ -212,7 +323,7 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
                 -1: [mesh_dim]
             },
             "other": {
-                self.dim_p: [mesh_dim]
+                0: [mesh_dim]
             },
             "bias": {},
             "output": {},
@@ -223,7 +334,7 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
         # get communication action
         output_comm_spec = self.get_communication_spec(
             sharding_spec=sharding_spec_mapping['output'],
-            communication_pattern=CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD,
+            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=mesh_dim)
         communication_action_mapping = {'output': output_comm_spec}
         return self.get_sharding_strategy(name=name,
@@ -237,7 +348,7 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
         dim_partition_dict_mapping = {
             "input": {},
             "other": {
-                self.dim_q: [mesh_dim]
+                -1: [mesh_dim]
             },
             "bias": {
                 -1: [mesh_dim]
@@ -294,7 +405,7 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
                 -1: [mesh_dim_0, mesh_dim_1]
             },
             "other": {
-                self.dim_p: [mesh_dim_0, mesh_dim_1]
+                0: [mesh_dim_0, mesh_dim_1]
             },
             "bias": {},
             "output": {},
@@ -304,7 +415,7 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
         # get communication action
         output_comm_spec = self.get_communication_spec(
             sharding_spec=sharding_spec_mapping['output'],
-            communication_pattern=CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD,
+            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=[mesh_dim_0, mesh_dim_1])
         communication_action_mapping = {'output': output_comm_spec}
@@ -319,7 +430,7 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
         dim_partition_dict_mapping = {
             "input": {},
             "other": {
-                self.dim_q: [mesh_dim_0, mesh_dim_1]
+                -1: [mesh_dim_0, mesh_dim_1]
             },
             "bias": {
                 -1: [mesh_dim_0, mesh_dim_1]
@@ -359,6 +470,190 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
         assert bias_data.logical_shape[-1] == other_data.logical_shape[-1]


-class BatchedMatMulStrategyGenerator(StrategyGenerator_V2):
-    """TODO: to be implemented"""
-    pass
+class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
+    """
+    Generate sharding strategies for the batched matrix multiplication.
+
+    A batched matrix multiplication can be viewed as
+    [b, i, k] x [b, k, j] -> [b, i, j]
+    """
+
+    def validate(self) -> bool:
+        input_op_data = self.op_data['input']
+        other_op_data = self.op_data['other']
+        assert input_op_data.data.dim() > 2 or other_op_data.data.dim() > 2
+    def split_one_batch_dim(self):
+        device_mesh_is_1d = True
+        if len(self.device_mesh.mesh_shape) == 1:
+            mesh_dim = 0
+        elif 1 in self.device_mesh.mesh_shape:
+            mesh_dim = self.device_mesh.mesh_shape.index(1)
+        else:
+            device_mesh_is_1d = False
+
+        if device_mesh_is_1d:
+            name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
+
+            # get sharding_spec
+            dim_partition_dict = {
+                "input": {
+                    0: [mesh_dim]
+                },
+                "other": {
+                    0: [mesh_dim]
+                },
+                "bias": {},
+                "output": {
+                    0: [mesh_dim]
+                }
+            }
+            sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
+
+            # get communication actions
+            communication_action_mapping = {}
+            return self.get_sharding_strategy(name=name,
+                                              sharding_spec_mapping=sharding_spec_mapping,
+                                              communication_action_mapping=communication_action_mapping)
+        else:
+            return None
+    def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
+        name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}'
+        dim_partition_dict = {
+            "input": {
+                0: [mesh_dim_0, mesh_dim_1]
+            },
+            "other": {
+                0: [mesh_dim_0, mesh_dim_1]
+            },
+            "bias": {},
+            "output": {
+                0: [mesh_dim_0, mesh_dim_1]
+            }
+        }
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
+
+        # get communication actions
+        communication_action_mapping = {}
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping=communication_action_mapping)
+    def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1):
+        name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}'
+        dim_partition_dict = {
+            "input": {
+                0: [mesh_dim_0],
+                -2: [mesh_dim_1]
+            },
+            "other": {
+                0: [mesh_dim_0]
+            },
+            "bias": {},
+            "output": {
+                0: mesh_dim_0,
+                -2: [mesh_dim_1]
+            }
+        }
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
+
+        # get communication actions
+        other_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['other'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=mesh_dim_1)
+        communication_action_mapping = {'other': other_comm_spec}
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping=communication_action_mapping)
+    def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1):
+        name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}'
+        dim_partition_dict = {
+            "input": {
+                0: [mesh_dim_0]
+            },
+            "other": {
+                0: [mesh_dim_0],
+                -1: [mesh_dim_1]
+            },
+            "bias": {
+                -1: [mesh_dim_1]
+            },
+            "output": {
+                0: [mesh_dim_0],
+                -1: [mesh_dim_1]
+            }
+        }
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
+
+        # get communication actions
+        input_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['input'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=mesh_dim_1)
+        communication_action_mapping = {'input': input_comm_spec}
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping=communication_action_mapping)
+    def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1):
+        name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}'
+        dim_partition_dict = {
+            "input": {
+                0: [mesh_dim_0],
+                -1: [mesh_dim_1]
+            },
+            "other": {
+                0: [mesh_dim_0],
+                -2: [mesh_dim_1]
+            },
+            "bias": {},
+            "output": {
+                0: [mesh_dim_0],
+                -2: [mesh_dim_1]
+            }
+        }
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
+
+        # get communication actions
+        output_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['output'],
+            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
+            logical_process_axis=mesh_dim_1)
+        communication_action_mapping = {'output': output_comm_spec}
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping=communication_action_mapping)
+    def generate(self) -> List[ShardingStrategy_V2]:
+        strategy_list = []
+
+        # split only the batch dimension
+        # Sb = Sb x Sb
+        # can be None as it is only for 1D device mesh
+        strategy = self.split_one_batch_dim()
+        if strategy:
+            strategy_list.append(strategy)
+
+        # split batch dim of two inputs and the i dim of the first tensor
+        # SbSi = SbSi x Sb
+        strategy_list.append(self.split_batch_dim_lhs_space(0, 1))
+        strategy_list.append(self.split_batch_dim_lhs_space(1, 0))
+
+        # split batch dim of two inputs and the j of the second tensor
+        # SbSj = Sb x SbSj
+        strategy_list.append(self.split_batch_dim_rhs_space(0, 1))
+        strategy_list.append(self.split_batch_dim_rhs_space(1, 0))
+
+        # split batch dim of two inputs and the k dim of two inputs
+        # Sb = SbSk x SbSk, need to all-reduce by k dim
+        strategy_list.append(self.split_batch_dim_both_contract(0, 1))
+        strategy_list.append(self.split_batch_dim_both_contract(1, 0))
+
+        # split two batch dim
+        strategy_list.append(self.split_two_batch_dim(0, 1))
+        strategy_list.append(self.split_two_batch_dim(1, 0))
+
+        return strategy_list
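
Aside (not part of the diff): a plain-torch check of split_batch_dim_both_contract: sharding the batch dim on one mesh axis and the contracted k dim on the other leaves every device with a partial [b/2, i, j] block, which the ALLREDUCE_FWD_IDENTITY_BWD action on the output sums back together.

import torch

b, i, k, j = 4, 3, 8, 5
A = torch.randn(b, i, k)
B = torch.randn(b, k, j)
full = torch.bmm(A, B)

# batch split in two along mesh axis 0, k split in two along mesh axis 1
for A_b, B_b, full_b in zip(A.chunk(2, dim=0), B.chunk(2, dim=0), full.chunk(2, dim=0)):
    partials = [torch.bmm(A_k, B_k) for A_k, B_k in zip(A_b.chunk(2, dim=2), B_b.chunk(2, dim=1))]
    allreduced = partials[0] + partials[1]    # forward all-reduce over the k mesh axis
    assert torch.allclose(full_b, allreduced, atol=1e-5)
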

View File

@@ -7,7 +7,7 @@ from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
 from colossalai.tensor.sharding_spec import ShardingSpec
 from colossalai.device.device_mesh import DeviceMesh
 from typing import Dict, List, Union, Any
-from ..sharding_strategy import OperationData, ShardingStrategy_V2, TrainCycleItem
+from ..sharding_strategy import OperationData, ShardingStrategy_V2, TrainCycleItem, OperationDataType


 class StrategyGenerator_V2(ABC):
@@ -21,6 +21,10 @@ class StrategyGenerator_V2(ABC):
         self.op_data = operation_data_mapping
         self.device_mesh = device_mesh

+    def is_param(self, op_data_name):
+        other_data = self.op_data[op_data_name]
+        return other_data.type == OperationDataType.PARAM
+
     def get_sharding_strategy(self, name: str, sharding_spec_mapping: Dict[str, ShardingSpec],
                               communication_action_mapping: Dict[str, CommSpec]):
         """
@@ -80,7 +84,7 @@ class StrategyGenerator_V2(ABC):
         Compute the communication cost involved in the forward and backward iteration.
         """

-        comm_cost = TrainCycleItem(fwd=0, bwd=0)
+        comm_cost = TrainCycleItem(fwd=0, bwd=0, total=0)

         def _compute_and_add(data: OperationData, comm_spec: CommSpec):
             num_ele_in_comm = comm_spec.get_comm_cost()
@@ -92,7 +96,7 @@ class StrategyGenerator_V2(ABC):
             # TODO: comm_spec.get_comm_cost should return a TrainCycleItem instead of the total cost.
             # it works fine here because only REDUCE_FWD_IDENTITY_BWD and IDENTITY_FWD_ALLREDUCE_BWD are used,
             # so total cost is either for fwd or bwd.
-            if comm_spec.comm_pattern == CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD:
+            if comm_spec.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD:
                 comm_cost.fwd += cost
             elif comm_spec.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD:
                 comm_cost.fwd += cost
@@ -102,9 +106,12 @@ class StrategyGenerator_V2(ABC):
         # check if communication action exists
         # if so, loop over each action and compute the cost of each action
         if strategy.communication_actions is not None:
-            for operand, comm_spec in strategy.communication_actions:
+            for operand, comm_spec in strategy.communication_actions.items():
                 _compute_and_add(operand, comm_spec)

+        # update the total cost
+        comm_cost.total = comm_cost.fwd + comm_cost.bwd
+
         # update the communication cost attribute in-place
         strategy.communication_cost = comm_cost
         return strategy
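
Aside (not part of the diff): the .items() fix above matters because iterating a dict directly yields only its keys, so the old loop could never unpack an (operand, comm_spec) pair.

actions = {'input': 'all-reduce on mesh axis 0', 'output': 'identity'}

for key in actions:
    print(key)                      # 'input', 'output' -- keys only

for operand, comm_spec in actions.items():
    print(operand, '->', comm_spec)

# the replaced form `for operand, comm_spec in actions:` fails to unpack,
# because each iteration yields a single key rather than a (key, value) pair
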
@@ -146,7 +153,7 @@ class StrategyGenerator_V2(ABC):
         pass

     @abstractmethod
-    def validate(self, *args, **kwargs) -> bool:
+    def validate(self) -> bool:
         """
         Validate if the operands are of desired shape.
         If True, means this generator can be used for the current operation.

View File

@@ -8,9 +8,9 @@ from colossalai.device.device_mesh import DeviceMesh


 def test_linear_module_handler():
-    model = nn.Sequential(nn.Linear(10, 20).to('meta'))
+    model = nn.Sequential(nn.Linear(16, 32).to('meta'))
     tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')})
+    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
     gm = ColoGraphModule(model, graph)
     physical_mesh_id = torch.arange(0, 4)
@@ -34,32 +34,55 @@ def test_linear_module_handler():
     assert mapping['input'].name == "input_1"
     assert mapping['input'].data.is_meta
-    assert mapping['input'].data.shape == torch.Size([4, 10])
+    assert mapping['input'].data.shape == torch.Size([4, 16])
     assert mapping['input'].type == OperationDataType.ARG
-    assert mapping['input'].logical_shape == torch.Size([4, 10])
+    assert mapping['input'].logical_shape == torch.Size([4, 16])

     assert mapping['other'].name == "weight"
     assert mapping['other'].data.is_meta
-    assert mapping['other'].data.shape == torch.Size([20, 10])
+    assert mapping['other'].data.shape == torch.Size([32, 16])
     assert mapping['other'].type == OperationDataType.PARAM
-    assert mapping['other'].logical_shape == torch.Size([10, 20])
+    assert mapping['other'].logical_shape == torch.Size([16, 32])

     assert mapping['bias'].name == "bias"
     assert mapping['bias'].data.is_meta
-    assert mapping['bias'].data.shape == torch.Size([20])
+    assert mapping['bias'].data.shape == torch.Size([32])
     assert mapping['bias'].type == OperationDataType.PARAM
-    assert mapping['other'].logical_shape == torch.Size([10, 20])
+    assert mapping['other'].logical_shape == torch.Size([16, 32])

     assert mapping['output'].name == "_0"
     assert mapping['output'].data.is_meta
-    assert mapping['output'].data.shape == torch.Size([4, 20])
+    assert mapping['output'].data.shape == torch.Size([4, 32])
     assert mapping['output'].type == OperationDataType.OUTPUT

+    strategies_vector = handler.register_strategy()
+    strategy_name_list = [val.name for val in strategies_vector]
+
+    # SS = SR x RS
+    assert 'S0S1 = S0R x RS1' in strategy_name_list
+    assert 'S1S0 = S1R x RS0' in strategy_name_list
+
+    # SR = SS x SR
+    assert 'S0R = S0S1 x S1R' in strategy_name_list
+    assert 'S1R = S1S0 x S0R' in strategy_name_list
+
+    # RS = RS x SS
+    assert 'RS0 = RS1 x S1S0' in strategy_name_list
+    assert 'RS1 = RS0 x S0S1' in strategy_name_list
+
+    # RR = RS x SR
+    assert 'RR = RS0 x S0R' in strategy_name_list
+    assert 'RR = RS1 x S1R' in strategy_name_list
+
+    # RS= RR x RS
+    assert 'RS0 = RR x RS0' in strategy_name_list
+    assert 'RS1 = RR x RS1' in strategy_name_list
+

 def test_linear_function_handler():
-    model = nn.Linear(10, 20).to('meta')
+    model = nn.Linear(16, 32).to('meta')
     tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')})
+    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
     gm = ColoGraphModule(model, graph)
     physical_mesh_id = torch.arange(0, 4)
@@ -77,27 +100,50 @@ def test_linear_function_handler():
     assert mapping['input'].name == "input_1"
     assert mapping['input'].data.is_meta
-    assert mapping['input'].data.shape == torch.Size([4, 10])
+    assert mapping['input'].data.shape == torch.Size([4, 16])
     assert mapping['input'].type == OperationDataType.ARG
-    assert mapping['input'].logical_shape == torch.Size([4, 10])
+    assert mapping['input'].logical_shape == torch.Size([4, 16])

     assert mapping['other'].name == "weight"
     assert mapping['other'].data.is_meta
-    assert mapping['other'].data.shape == torch.Size([20, 10])
+    assert mapping['other'].data.shape == torch.Size([32, 16])
     assert mapping['other'].type == OperationDataType.PARAM
-    assert mapping['other'].logical_shape == torch.Size([10, 20])
+    assert mapping['other'].logical_shape == torch.Size([16, 32])

     assert mapping['bias'].name == "bias"
     assert mapping['bias'].data.is_meta
-    assert mapping['bias'].data.shape == torch.Size([20])
+    assert mapping['bias'].data.shape == torch.Size([32])
     assert mapping['bias'].type == OperationDataType.PARAM
-    assert mapping['other'].logical_shape == torch.Size([10, 20])
+    assert mapping['other'].logical_shape == torch.Size([16, 32])

     assert mapping['output'].name == "linear"
     assert mapping['output'].data.is_meta
-    assert mapping['output'].data.shape == torch.Size([4, 20])
+    assert mapping['output'].data.shape == torch.Size([4, 32])
     assert mapping['output'].type == OperationDataType.OUTPUT

+    strategies_vector = handler.register_strategy()
+    strategy_name_list = [val.name for val in strategies_vector]
+
+    # SS = SR x RS
+    assert 'S0S1 = S0R x RS1' in strategy_name_list
+    assert 'S1S0 = S1R x RS0' in strategy_name_list
+
+    # SR = SS x SR
+    assert 'S0R = S0S1 x S1R' in strategy_name_list
+    assert 'S1R = S1S0 x S0R' in strategy_name_list
+
+    # RS = RS x SS
+    assert 'RS0 = RS1 x S1S0' in strategy_name_list
+    assert 'RS1 = RS0 x S0S1' in strategy_name_list
+
+    # RR = RS x SR
+    assert 'RR = RS0 x S0R' in strategy_name_list
+    assert 'RR = RS1 x S1R' in strategy_name_list
+
+    # RS= RR x RS
+    assert 'RS0 = RR x RS0' in strategy_name_list
+    assert 'RS1 = RR x RS1' in strategy_name_list
+

 if __name__ == '__main__':
     test_linear_module_handler()

View File

@@ -1,6 +1,3 @@
-from curses import meta
-from math import dist
-from xml.dom import HierarchyRequestErr
 from colossalai.fx.tracer import meta_patch
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.fx.tracer.meta_patch.patched_function import python_ops

View File

@@ -1,6 +1,3 @@
-from curses import meta
-from math import dist
-from xml.dom import HierarchyRequestErr
 from colossalai.fx.tracer import meta_patch
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.fx.tracer.meta_patch.patched_function import python_ops