ColossalAI/colossalai/auto_parallel/solver/strategy/matmul_strategy_generator.py

686 lines
29 KiB
Python

from audioop import bias
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator_V2
from typing import List
class MatMulStrategyGenerator(StrategyGenerator_V2):
"""
MatMulStrategyGenerator is a generic class to cover all matrix multiplication cases.
The operation data is defined as `output = input x other + bias`.
"""
@property
def has_bias(self):
return 'bias' in self.op_data
def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'other': self._compute_size_in_bytes(strategy, "other"),
'output': self._compute_size_in_bytes(strategy, "output")
}
if self.has_bias:
bias_size = self._compute_size_in_bytes(strategy, "bias")
size_mapping['bias'] = bias_size
# compute fwd cost incurred
# fwd_cost = input + other + bias + output
fwd_activation_cost = sum([v for k, v in size_mapping.items() if not self.is_param(k)])
fwd_parameter_cost = sum([v for k, v in size_mapping.items() if self.is_param(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
# compute bwd cost incurred
# bwd_cost = input_grad + bias_grad
bwd_activation_cost = sum([v for k, v in size_mapping.items() if k in ['input', 'other', 'bias']])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=0)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + 0)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
class DotProductStrategyGenerator(MatMulStrategyGenerator):
def validate(self) -> bool:
input_op_data = self.op_data['input']
other_op_data = self.op_data['other']
assert input_op_data.data.dim() == 1 and other_op_data.data.dim() == 1
def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
fwd_compute_cost = sharded_input_shape[0]
bwd_compute_cost = sharded_input_shape * 2
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
return compute_cost
def no_split(self):
name = f'R = R dot R'
dim_partition_dict = {"input": {}, "other": {}, "output": {}, 'bias': {}}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
communication_action_mapping = {}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def split_one_dim(self, mesh_dim):
name = f'R = S{mesh_dim} dot S{mesh_dim}'
# get sharding spec
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "output": {}, "bias": {0: [mesh_dim]}}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
# get communication action
output_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['output'],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim)
communication_action_mapping = {"output": output_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def generate(self) -> List[ShardingStrategy_V2]:
strategy_list = []
# do not split dimensions for dot product
# R = R dot R
strategy_list.append(self.no_split())
# split two tensors in the same dimensions
# S = S dot S
strategy_list.append(self.split_one_dim(0))
strategy_list.append(self.split_one_dim(1))
return strategy_list
class MatVecStrategyGenerator(MatMulStrategyGenerator):
def validate(self) -> bool:
input_op_data = self.op_data['input']
other_op_data = self.op_data['other']
assert input_op_data.data.dim() > 1 and other_op_data.data.dim() == 1
def no_split(self):
name = "R = R x R"
dim_partition_dict = {"input": {}, "other": {}, "output": {}, "bias": {}}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
def split_input_batch(self, mesh_dim):
name = f'S{mesh_dim}R = S{mesh_dim}R x R'
# get sharding spec
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}, "bias": {}}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
# get communication action
other_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['other'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim)
bias_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim)
communication_action_mapping = {'other': other_comm_spec, 'bias': bias_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def generate(self) -> List[ShardingStrategy_V2]:
strategy_list = []
# no split
strategy_list.append(self.no_split())
# split the batch dim for the first tensor only
strategy_list.append(self.split_input_batch(0))
strategy_list.append(self.split_input_batch(1))
return strategy_list
class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
# C = AB
# C: [M, N], A: [M, P], B: [P, N]
# fwd cost = MNP (only count mul)
# bwd: 2 x fwd_cost
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()
dim_m_val = reduce(operator.mul, sharded_input_shape[:-1])
dim_n_val = sharded_other_shape[-1]
dim_p_val = sharded_other_shape[0]
fwd_compute_cost = dim_m_val * dim_n_val * dim_p_val
bwd_compute_cost = fwd_compute_cost * 2
compute_cost = TrainCycleItem(fwd=bwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
strategy.compute_cost = compute_cost
def generate(self) -> List[ShardingStrategy_V2]:
strategies = []
# SS = SR x RS
strategies.append(self.split_lhs_space_rhs_space(0, 1))
strategies.append(self.split_lhs_space_rhs_space(1, 0))
# SR = SS x SR
strategies.append(self.split_lhs_space_both_contract(0, 1))
strategies.append(self.split_lhs_space_both_contract(1, 0))
# RS = RS x SS
strategies.append(self.split_rhs_space_both_contract(0, 1))
strategies.append(self.split_rhs_space_both_contract(1, 0))
# RR= RS x SR
strategies.append(self.recompute_split_both_contract(0))
strategies.append(self.recompute_split_both_contract(1))
# RS = RR x RS
strategies.append(self.split_rhs_space_only(0))
strategies.append(self.split_rhs_space_only(1))
# S01R = S01R x RR
strategies.append(self.split_lhs_1st_dim_1d(0, 1))
# RR = RS01 x S01R
strategies.append(self.split_lhs_2nd_dim_1d(0, 1))
# RS01 = RR x RS01
strategies.append(self.split_rhs_2nd_dim_1d(0, 1))
# update mete info on cost
for strategy in strategies:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return strategies
def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
# handle case SS = SR x RS
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0]
},
"other": {
-1: [mesh_dim_1]
},
"bias": {
-1: [mesh_dim_1]
},
"output": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
},
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# set communication action
input_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1)
other_comm_spec = self.get_communication_spec(
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
bias_comm_spec = self.get_communication_spec(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
communication_action_mapping = {"input": input_comm_spec, "other": other_comm_spec, "bias": bias_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
# handle the case SR = SS x SR
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
# get sharding spec mapping
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
},
"other": {
0: [mesh_dim_1]
},
"bias": {},
"output": {
0: [mesh_dim_0]
},
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action mapping
input_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
output_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1)
bias_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1)
communication_action_mapping = {"input": input_comm_spec, 'output': output_comm_spec, 'bias': bias_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
# get sharding specs
dim_partition_dict_mapping = {
"input": {
-1: [mesh_dim_0]
},
"other": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
},
"bias": {
-1: [mesh_dim_1]
},
"output": {
-1: [mesh_dim_1]
},
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication actions
output_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['output'],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0)
input_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['input'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1)
communication_action_mapping = {"output": output_comm_spec, "input": input_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def recompute_split_both_contract(self, mesh_dim):
name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
# get sharding spec
dim_partition_dict_mapping = {
"input": {
-1: [mesh_dim]
},
"other": {
0: [mesh_dim]
},
"bias": {},
"output": {},
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action
output_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['output'],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim)
communication_action_mapping = {'output': output_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def split_rhs_space_only(self, mesh_dim):
name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
# get sharding spec
dim_partition_dict_mapping = {
"input": {},
"other": {
-1: [mesh_dim]
},
"bias": {
-1: [mesh_dim]
},
"output": {
-1: [mesh_dim]
},
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication actions
input_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['input'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim)
communication_action_mapping = {'input': input_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
# get sharding spec
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0, mesh_dim_1]
},
"other": {},
"bias": {},
"output": {
0: [mesh_dim_0, mesh_dim_1]
},
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action
other_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['other'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
bias_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communcation_action_mapping = {"other": other_comm_spec, "bias": bias_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communcation_action_mapping)
def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
# get sharding spec
dim_partition_dict_mapping = {
"input": {
-1: [mesh_dim_0, mesh_dim_1]
},
"other": {
0: [mesh_dim_0, mesh_dim_1]
},
"bias": {},
"output": {},
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action
output_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['output'],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping = {'output': output_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
# get sharding spec
dim_partition_dict_mapping = {
"input": {},
"other": {
-1: [mesh_dim_0, mesh_dim_1]
},
"bias": {
-1: [mesh_dim_0, mesh_dim_1]
},
"output": {
-1: [mesh_dim_0, mesh_dim_1]
},
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action
input_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['input'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping = {'input': input_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def validate(self) -> bool:
assert "input" in self.op_data
assert "other" in self.op_data
# make sure the other has 2 dim
input_data = self.op_data['input']
other_data = self.op_data['other']
assert input_data.data.dim() > 0 and other_data.data.dim() == 2
assert other_data.logical_shape[0] == input_data.logical_shape[-1]
# check if bias has the same a valid dim
has_bias = "bias" in self.op_data
if has_bias:
bias_data = self.op_data['bias']
assert bias_data.logical_shape[-1] == other_data.logical_shape[-1]
class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
"""
Generate sharding strategies for the batched matrix multiplication.
A batched matrix multiplication can be viewed as
[b, i, k] x [b, k, j] -> [b, i, j]
"""
def validate(self) -> bool:
input_op_data = self.op_data['input']
other_op_data = self.op_data['other']
assert input_op_data.data.dim() > 2 or other_op_data.data.dim() > 2
def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
return self.op_data['input'].data.shape[-1] * reduce(operator.mul, self.op_data['output'].data.shape)
def split_one_batch_dim(self, mesh_dim):
name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
# get sharding_spec
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "bias": {}, "output": {0: [mesh_dim]}}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
# get communication actions
bias_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim)
communication_action_mapping = {"bias": bias_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}'
dim_partition_dict = {
"input": {
0: [mesh_dim_0, mesh_dim_1]
},
"other": {
0: [mesh_dim_0, mesh_dim_1]
},
"bias": {},
"output": {
0: [mesh_dim_0, mesh_dim_1]
}
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
# get communication actions
bias_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping = {"bias": bias_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1):
name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}'
dim_partition_dict = {
"input": {
0: [mesh_dim_0],
-2: [mesh_dim_1]
},
"other": {
0: [mesh_dim_0]
},
"bias": {},
"output": {
0: [mesh_dim_0],
-2: [mesh_dim_1]
}
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
# get communication actions
other_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['other'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1)
bias_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping = {'other': other_comm_spec, 'bias': bias_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1):
name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}'
dim_partition_dict = {
"input": {
0: [mesh_dim_0]
},
"other": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
},
"bias": {
-1: [mesh_dim_1]
},
"output": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
}
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
# get communication actions
input_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['input'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1)
bias_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
communication_action_mapping = {'input': input_comm_spec, 'bias': bias_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1):
name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}'
dim_partition_dict = {
"input": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
},
"other": {
0: [mesh_dim_0],
-2: [mesh_dim_1]
},
"bias": {},
"output": {
0: [mesh_dim_0],
-2: [mesh_dim_1]
}
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
# get communication actions
output_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['output'],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1)
bias_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
communication_action_mapping = {'output': output_comm_spec, 'bias': bias_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def generate(self) -> List[ShardingStrategy_V2]:
strategy_list = []
device_mesh_is_1d = True
if len(self.device_mesh.mesh_shape) == 2 and 1 not in self.device_mesh.mesh_shape:
device_mesh_is_1d = False
if device_mesh_is_1d:
# split only the batch dimension
# Sb = Sb x Sb
# can be None as it is only for 1D device mesh
# only for 1D device mesh
if len(self.device_mesh.mesh_shape) == 1:
mesh_dim = 0
else:
mesh_dim = self.device_mesh.mesh_shape.index(1)
strategy_list.append(self.split_one_batch_dim(mesh_dim))
else:
# for 2D device mesh
# split batch dim of two inputs and the i dim of the first tensor
# SbSi = SbSi x Sb
strategy_list.append(self.split_batch_dim_lhs_space(0, 1))
strategy_list.append(self.split_batch_dim_lhs_space(1, 0))
# split batch dim of two inputs and the j of the second tensor
# SbSj = Sb x SbSj
strategy_list.append(self.split_batch_dim_rhs_space(0, 1))
strategy_list.append(self.split_batch_dim_rhs_space(1, 0))
# split batch dim of two inputs and the k dim of two inputs
# Sb = SbSk x SbSk, need to all-reduce by k dim
strategy_list.append(self.split_batch_dim_both_contract(0, 1))
strategy_list.append(self.split_batch_dim_both_contract(1, 0))
# split two batch dim
strategy_list.append(self.split_two_batch_dim(0, 1))
strategy_list.append(self.split_two_batch_dim(1, 0))
return strategy_list