mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] implemented all matmul strategy generator (#1650)
parent
03978aad45
commit
30e50c8b4a
|
@ -50,8 +50,16 @@ class LinearModuleHandler(ModuleHandler):
|
||||||
if op_data.name == "weight":
|
if op_data.name == "weight":
|
||||||
assert op_data.logical_shape != op_data.data.shape
|
assert op_data.logical_shape != op_data.data.shape
|
||||||
dim_partition_dict = sharding_spec.dim_partition_dict
|
dim_partition_dict = sharding_spec.dim_partition_dict
|
||||||
|
|
||||||
# switch first and last dim of the linear module weight
|
# switch first and last dim of the linear module weight
|
||||||
dim_partition_dict[0], dim_partition_dict[-1] = dim_partition_dict[-1], dim_partition_dict[0]
|
first_dim_partition = dim_partition_dict.pop(-1, None)
|
||||||
|
last_dim_partition = dim_partition_dict.pop(0, None)
|
||||||
|
|
||||||
|
if first_dim_partition:
|
||||||
|
dim_partition_dict[0] = first_dim_partition
|
||||||
|
|
||||||
|
if last_dim_partition:
|
||||||
|
dim_partition_dict[-1] = last_dim_partition
|
||||||
|
|
||||||
# re-init the sharding spec
|
# re-init the sharding spec
|
||||||
sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
|
sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
|
||||||
|
@ -111,8 +119,16 @@ class LinearFunctionHandler(NodeHandler):
|
||||||
if op_data.name == str(self.node.args[1]):
|
if op_data.name == str(self.node.args[1]):
|
||||||
assert op_data.logical_shape != op_data.data.shape
|
assert op_data.logical_shape != op_data.data.shape
|
||||||
dim_partition_dict = sharding_spec.dim_partition_dict
|
dim_partition_dict = sharding_spec.dim_partition_dict
|
||||||
|
|
||||||
# switch first and last dim of the linear module weight
|
# switch first and last dim of the linear module weight
|
||||||
dim_partition_dict[0], dim_partition_dict[-1] = dim_partition_dict[-1], dim_partition_dict[0]
|
first_dim_partition = dim_partition_dict.pop(-1, None)
|
||||||
|
last_dim_partition = dim_partition_dict.pop(0, None)
|
||||||
|
|
||||||
|
if first_dim_partition:
|
||||||
|
dim_partition_dict[0] = first_dim_partition
|
||||||
|
|
||||||
|
if last_dim_partition:
|
||||||
|
dim_partition_dict[-1] = last_dim_partition
|
||||||
|
|
||||||
# re-init the sharding spec
|
# re-init the sharding spec
|
||||||
sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
|
sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
|
||||||
|
|
|
@ -33,12 +33,12 @@ class NodeHandler(ABC):
|
||||||
Register different sharding strategies for the current node.
|
Register different sharding strategies for the current node.
|
||||||
"""
|
"""
|
||||||
strategy_generators = self.get_strategy_generator()
|
strategy_generators = self.get_strategy_generator()
|
||||||
operand_mapping = self.get_operation_data_mapping()
|
|
||||||
for generator in strategy_generators:
|
for generator in strategy_generators:
|
||||||
strategies = generator.generate(operand_mapping)
|
strategies = generator.generate()
|
||||||
self.strategies_vector.extend(strategies)
|
self.strategies_vector.extend(strategies)
|
||||||
|
|
||||||
self.strategies_vector = map(self.post_process, self.strategies_vector)
|
strategies_vector = map(self.post_process, self.strategies_vector)
|
||||||
|
self.strategies_vector = list(strategies_vector)
|
||||||
return self.strategies_vector
|
return self.strategies_vector
|
||||||
|
|
||||||
def post_process(self, strategy: ShardingStrategy_V2):
|
def post_process(self, strategy: ShardingStrategy_V2):
|
||||||
|
|
|
@ -75,6 +75,12 @@ class OperationData:
|
||||||
if self.logical_shape is None:
|
if self.logical_shape is None:
|
||||||
self.logical_shape = self.data.shape
|
self.logical_shape = self.data.shape
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
    """Return a compact debug string that shows only the operand's name and type."""
    summary = f'OperationData(name={self.name}, type={self.type})'
    return summary
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
    """Hash the operand by the string formed from its name and type."""
    identity = f'{self.name}-{self.type}'
    return hash(identity)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainCycleItem:
|
class TrainCycleItem:
|
||||||
|
|
|
@ -1,7 +1,4 @@
|
||||||
from cmath import log
|
|
||||||
from distutils.log import Log
|
|
||||||
import operator
|
import operator
|
||||||
import torch
|
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
|
from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
|
||||||
from colossalai.tensor.shape_consistency import CollectiveCommPattern
|
from colossalai.tensor.shape_consistency import CollectiveCommPattern
|
||||||
|
@ -9,17 +6,148 @@ from .strategy_generator import StrategyGenerator_V2
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
class DotProductStrategyGenerator(StrategyGenerator_V2):
|
class MatMulStrategyGenerator(StrategyGenerator_V2):
    """
    MatMulStrategyGenerator is a generic class to cover all matrix multiplication cases.
    The operation data is defined as `output = input x other + bias`.
    """

    @property
    def has_bias(self):
        """Whether a 'bias' operand is present in the operation data mapping."""
        return 'bias' in self.op_data

    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
        """
        Compute the forward, backward and total memory cost and store it on the strategy.

        Args:
            strategy: the sharding strategy whose `memory_cost` attribute is set in place.

        Returns:
            The same strategy object, with `memory_cost` populated.
        """
        operand_names = ['input', 'other', 'output']
        if self.has_bias:
            operand_names.append('bias')

        # per-operand size as reported by `_compute_size_in_bytes` for this strategy
        size_mapping = {name: self._compute_size_in_bytes(strategy, name) for name in operand_names}

        # forward pass: every operand is counted once, bucketed by whether it is a parameter
        fwd_activation_cost = sum(size for name, size in size_mapping.items() if not self.is_param(name))
        fwd_parameter_cost = sum(size for name, size in size_mapping.items() if self.is_param(name))
        fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)

        # backward pass: the sizes of input, other and (if present) bias are summed,
        # presumably standing in for their gradients; everything is booked as activation
        grad_operands = ('input', 'other', 'bias')
        bwd_activation_cost = sum(size for name, size in size_mapping.items() if name in grad_operands)
        bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=0)

        # total: activations of both passes, parameters only from the forward pass
        total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
                                    parameter=fwd_parameter_cost)
        strategy.memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
        return strategy
|
||||||
|
|
||||||
|
|
||||||
class MatVecStrategyGenerator(StrategyGenerator_V2):
|
class DotProductStrategyGenerator(MatMulStrategyGenerator):
    """
    Generate sharding strategies for the dot product of two 1D tensors.
    """

    def validate(self) -> bool:
        """
        Check that both operands are 1D tensors.

        Returns:
            True when the operands are valid.

        Raises:
            AssertionError: if either operand is not one-dimensional.
        """
        input_op_data = self.op_data['input']
        other_op_data = self.op_data['other']
        assert input_op_data.data.dim() == 1 and other_op_data.data.dim() == 1
        return True

    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
        """
        Compute the forward, backward and total compute cost of a strategy.

        The forward cost is the per-device length of the sharded input and the
        backward cost is twice that. The result is stored on
        `strategy.compute_cost` and also returned.

        Args:
            strategy: the sharding strategy to evaluate.

        Returns:
            The computed TrainCycleItem.
        """
        sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
        fwd_compute_cost = sharded_input_shape[0]
        # double the scalar cost; multiplying the shape itself would repeat the sequence
        bwd_compute_cost = fwd_compute_cost * 2
        compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
                                      bwd=bwd_compute_cost,
                                      total=fwd_compute_cost + bwd_compute_cost)
        # store in place, as LinearProjectionStrategyGenerator.update_compute_cost does
        strategy.compute_cost = compute_cost
        return compute_cost

    def no_split(self):
        """Build the fully replicated strategy `R = R dot R`, which needs no communication."""
        name = 'R = R dot R'
        dim_partition_dict = {"input": {}, "other": {}, "output": {}, 'bias': {}}
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
        communication_action_mapping = {}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_one_dim(self, mesh_dim):
        """
        Build the strategy that shards both operands along `mesh_dim`.

        The output is left unsharded and is all-reduced over `mesh_dim` in the forward pass.

        Args:
            mesh_dim: the device mesh axis used to shard dimension 0 of input, other and bias.
        """
        name = f'R = S{mesh_dim} dot S{mesh_dim}'

        # get sharding spec
        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "output": {}, "bias": {0: [mesh_dim]}}
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

        # get communication action
        output_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping['output'],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=mesh_dim)
        communication_action_mapping = {"output": output_comm_spec}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def generate(self) -> List[ShardingStrategy_V2]:
        """Return the replicated strategy followed by one sharded strategy per mesh axis (0 and 1)."""
        strategy_list = []

        # do not split dimensions for dot product
        # R = R dot R
        strategy_list.append(self.no_split())

        # split two tensors in the same dimensions
        # R = S dot S
        strategy_list.append(self.split_one_dim(0))
        strategy_list.append(self.split_one_dim(1))

        return strategy_list
|
||||||
|
|
||||||
|
|
||||||
class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
|
class MatVecStrategyGenerator(MatMulStrategyGenerator):
    """
    Generate sharding strategies for a product whose second operand is a 1D tensor.
    """

    def validate(self) -> bool:
        """
        Check that the input has more than one dimension and the other operand is 1D.

        Returns:
            True when the operands are valid.

        Raises:
            AssertionError: if the operand dimensions do not match this generator.
        """
        input_op_data = self.op_data['input']
        other_op_data = self.op_data['other']
        assert input_op_data.data.dim() > 1 and other_op_data.data.dim() == 1
        return True

    def no_split(self):
        """Build the fully replicated strategy `R = R x R`, which needs no communication."""
        name = "R = R x R"
        dim_partition_dict = {"input": {}, "other": {}, "output": {}, "bias": {}}
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping={})

    def split_input_batch(self, mesh_dim):
        """
        Build the strategy that shards dimension 0 of the input and output along `mesh_dim`.

        The other operand is left unsharded; it gets an all-reduce over `mesh_dim`
        in the backward pass.

        Args:
            mesh_dim: the device mesh axis used for sharding.
        """
        name = f'S{mesh_dim}R = S{mesh_dim}R x R'

        # get sharding spec
        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}}
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

        # get communication action
        other_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping['other'],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim)
        communication_action_mapping = {'other': other_comm_spec}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def generate(self) -> List[ShardingStrategy_V2]:
        """Return the replicated strategy followed by one batch-sharded strategy per mesh axis (0 and 1)."""
        strategy_list = []

        # no split
        strategy_list.append(self.no_split())

        # split the batch dim for the first tensor only
        strategy_list.append(self.split_input_batch(0))
        strategy_list.append(self.split_input_batch(1))

        return strategy_list
|
||||||
|
|
||||||
|
|
||||||
|
class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
|
||||||
|
|
||||||
def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
|
def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
|
||||||
# C = AB
|
# C = AB
|
||||||
|
@ -39,23 +167,6 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
|
||||||
total=fwd_compute_cost + bwd_compute_cost)
|
total=fwd_compute_cost + bwd_compute_cost)
|
||||||
strategy.compute_cost = compute_cost
|
strategy.compute_cost = compute_cost
|
||||||
|
|
||||||
def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
|
|
||||||
input_size = self._compute_size_in_bytes(strategy, "input")
|
|
||||||
other_size = self._compute_size_in_bytes(strategy, "input")
|
|
||||||
|
|
||||||
if "bias" in self.op_data:
|
|
||||||
bias_size = self._compute_size_in_bytes(strategy, "bias")
|
|
||||||
else:
|
|
||||||
bias_size = 0
|
|
||||||
output_size = self._compute_size_in_bytes(strategy, "output")
|
|
||||||
|
|
||||||
fwd_mem_cost = MemoryCost(activation=output_size, parameter=other_size + bias_size)
|
|
||||||
bwd_mem_cost = MemoryCost(activation=input_size + other_size + bias_size, parameter=other_size)
|
|
||||||
total_mem_cost = MemoryCost(activation=input_size + 2 * output_size + bias_size,
|
|
||||||
parameter=other_size + bias_size)
|
|
||||||
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
|
|
||||||
strategy.memory_cost = memory_cost
|
|
||||||
|
|
||||||
def generate(self) -> List[ShardingStrategy_V2]:
|
def generate(self) -> List[ShardingStrategy_V2]:
|
||||||
strategies = []
|
strategies = []
|
||||||
|
|
||||||
|
@ -104,7 +215,7 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
|
||||||
0: [mesh_dim_0]
|
0: [mesh_dim_0]
|
||||||
},
|
},
|
||||||
"other": {
|
"other": {
|
||||||
self.dim_q: [mesh_dim_1]
|
-1: [mesh_dim_1]
|
||||||
},
|
},
|
||||||
"bias": {
|
"bias": {
|
||||||
-1: [mesh_dim_1]
|
-1: [mesh_dim_1]
|
||||||
|
@ -143,7 +254,7 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
|
||||||
-1: [mesh_dim_1]
|
-1: [mesh_dim_1]
|
||||||
},
|
},
|
||||||
"other": {
|
"other": {
|
||||||
self.dim_p: [mesh_dim_1]
|
0: [mesh_dim_1]
|
||||||
},
|
},
|
||||||
"bias": {},
|
"bias": {},
|
||||||
"output": {
|
"output": {
|
||||||
|
@ -159,7 +270,7 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
|
||||||
logical_process_axis=mesh_dim_0)
|
logical_process_axis=mesh_dim_0)
|
||||||
output_comm_spec = self.get_communication_spec(
|
output_comm_spec = self.get_communication_spec(
|
||||||
sharding_spec=sharding_spec_mapping["output"],
|
sharding_spec=sharding_spec_mapping["output"],
|
||||||
communication_pattern=CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD,
|
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
|
||||||
logical_process_axis=mesh_dim_1)
|
logical_process_axis=mesh_dim_1)
|
||||||
|
|
||||||
communication_action_mapping = {"input": input_comm_spec, 'output': output_comm_spec}
|
communication_action_mapping = {"input": input_comm_spec, 'output': output_comm_spec}
|
||||||
|
@ -177,8 +288,8 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
|
||||||
-1: [mesh_dim_0]
|
-1: [mesh_dim_0]
|
||||||
},
|
},
|
||||||
"other": {
|
"other": {
|
||||||
self.dim_p: [mesh_dim_0],
|
0: [mesh_dim_0],
|
||||||
self.dim_q: [mesh_dim_1]
|
-1: [mesh_dim_1]
|
||||||
},
|
},
|
||||||
"bias": {
|
"bias": {
|
||||||
-1: [mesh_dim_1]
|
-1: [mesh_dim_1]
|
||||||
|
@ -192,7 +303,7 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
|
||||||
# get communication actions
|
# get communication actions
|
||||||
output_comm_spec = self.get_communication_spec(
|
output_comm_spec = self.get_communication_spec(
|
||||||
sharding_spec=sharding_spec_mapping['output'],
|
sharding_spec=sharding_spec_mapping['output'],
|
||||||
communication_pattern=CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD,
|
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
|
||||||
logical_process_axis=mesh_dim_0)
|
logical_process_axis=mesh_dim_0)
|
||||||
input_comm_spec = self.get_communication_spec(
|
input_comm_spec = self.get_communication_spec(
|
||||||
sharding_spec=sharding_spec_mapping['input'],
|
sharding_spec=sharding_spec_mapping['input'],
|
||||||
|
@ -212,7 +323,7 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
|
||||||
-1: [mesh_dim]
|
-1: [mesh_dim]
|
||||||
},
|
},
|
||||||
"other": {
|
"other": {
|
||||||
self.dim_p: [mesh_dim]
|
0: [mesh_dim]
|
||||||
},
|
},
|
||||||
"bias": {},
|
"bias": {},
|
||||||
"output": {},
|
"output": {},
|
||||||
|
@ -223,7 +334,7 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
|
||||||
# get communication action
|
# get communication action
|
||||||
output_comm_spec = self.get_communication_spec(
|
output_comm_spec = self.get_communication_spec(
|
||||||
sharding_spec=sharding_spec_mapping['output'],
|
sharding_spec=sharding_spec_mapping['output'],
|
||||||
communication_pattern=CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD,
|
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
|
||||||
logical_process_axis=mesh_dim)
|
logical_process_axis=mesh_dim)
|
||||||
communication_action_mapping = {'output': output_comm_spec}
|
communication_action_mapping = {'output': output_comm_spec}
|
||||||
return self.get_sharding_strategy(name=name,
|
return self.get_sharding_strategy(name=name,
|
||||||
|
@ -237,7 +348,7 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
|
||||||
dim_partition_dict_mapping = {
|
dim_partition_dict_mapping = {
|
||||||
"input": {},
|
"input": {},
|
||||||
"other": {
|
"other": {
|
||||||
self.dim_q: [mesh_dim]
|
-1: [mesh_dim]
|
||||||
},
|
},
|
||||||
"bias": {
|
"bias": {
|
||||||
-1: [mesh_dim]
|
-1: [mesh_dim]
|
||||||
|
@ -294,7 +405,7 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
|
||||||
-1: [mesh_dim_0, mesh_dim_1]
|
-1: [mesh_dim_0, mesh_dim_1]
|
||||||
},
|
},
|
||||||
"other": {
|
"other": {
|
||||||
self.dim_p: [mesh_dim_0, mesh_dim_1]
|
0: [mesh_dim_0, mesh_dim_1]
|
||||||
},
|
},
|
||||||
"bias": {},
|
"bias": {},
|
||||||
"output": {},
|
"output": {},
|
||||||
|
@ -304,7 +415,7 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
|
||||||
# get communication action
|
# get communication action
|
||||||
output_comm_spec = self.get_communication_spec(
|
output_comm_spec = self.get_communication_spec(
|
||||||
sharding_spec=sharding_spec_mapping['output'],
|
sharding_spec=sharding_spec_mapping['output'],
|
||||||
communication_pattern=CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD,
|
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
|
||||||
logical_process_axis=[mesh_dim_0, mesh_dim_1])
|
logical_process_axis=[mesh_dim_0, mesh_dim_1])
|
||||||
communication_action_mapping = {'output': output_comm_spec}
|
communication_action_mapping = {'output': output_comm_spec}
|
||||||
|
|
||||||
|
@ -319,7 +430,7 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
|
||||||
dim_partition_dict_mapping = {
|
dim_partition_dict_mapping = {
|
||||||
"input": {},
|
"input": {},
|
||||||
"other": {
|
"other": {
|
||||||
self.dim_q: [mesh_dim_0, mesh_dim_1]
|
-1: [mesh_dim_0, mesh_dim_1]
|
||||||
},
|
},
|
||||||
"bias": {
|
"bias": {
|
||||||
-1: [mesh_dim_0, mesh_dim_1]
|
-1: [mesh_dim_0, mesh_dim_1]
|
||||||
|
@ -359,6 +470,190 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
|
||||||
assert bias_data.logical_shape[-1] == other_data.logical_shape[-1]
|
assert bias_data.logical_shape[-1] == other_data.logical_shape[-1]
|
||||||
|
|
||||||
|
|
||||||
class BatchedMatMulStrategyGenerator(StrategyGenerator_V2):
|
class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
    """
    Generate sharding strategies for the batched matrix multiplication.

    A batched matrix multiplication can be viewed as
    [b, i, k] x [b, k, j] -> [b, i, j]
    """

    def validate(self) -> bool:
        """
        Check that at least one operand has more than two dimensions.

        Returns:
            True when the operands are valid.

        Raises:
            AssertionError: if neither operand has more than two dimensions.
        """
        input_op_data = self.op_data['input']
        other_op_data = self.op_data['other']
        assert input_op_data.data.dim() > 2 or other_op_data.data.dim() > 2
        return True

    def split_one_batch_dim(self):
        """
        Build the `Sb = Sb x Sb` strategy, which shards only the batch dimension.

        Returns:
            The sharding strategy when the device mesh is treated as 1D (it has a
            single axis, or one of its axes has size 1); otherwise None.
        """
        mesh_shape = self.device_mesh.mesh_shape
        if len(mesh_shape) == 1:
            mesh_dim = 0
        elif 1 in mesh_shape:
            # NOTE(review): this selects the axis whose size is 1, not the remaining
            # non-trivial axis -- confirm that this is the intended axis to shard over
            mesh_dim = mesh_shape.index(1)
        else:
            # a real multi-axis mesh: this strategy does not apply
            return None

        name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'

        # get sharding_spec
        dim_partition_dict = {
            "input": {
                0: [mesh_dim]
            },
            "other": {
                0: [mesh_dim]
            },
            "bias": {},
            "output": {
                0: [mesh_dim]
            }
        }
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

        # no communication is registered for this strategy
        communication_action_mapping = {}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
        """
        Build the strategy that shards the batch dimension over both mesh axes.

        Args:
            mesh_dim_0: first device mesh axis applied to dimension 0.
            mesh_dim_1: second device mesh axis applied to dimension 0.
        """
        name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}'
        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0, mesh_dim_1]
            },
            "other": {
                0: [mesh_dim_0, mesh_dim_1]
            },
            "bias": {},
            "output": {
                0: [mesh_dim_0, mesh_dim_1]
            }
        }
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

        # no communication is registered for this strategy
        communication_action_mapping = {}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1):
        """
        Build the `SbSi = SbSi x Sb` strategy.

        The batch dimension of every tensor is sharded along `mesh_dim_0`; the
        second-to-last dimension of the input and output is sharded along `mesh_dim_1`.
        The other operand gets an all-reduce over `mesh_dim_1` in the backward pass.

        Args:
            mesh_dim_0: device mesh axis for the batch dimension.
            mesh_dim_1: device mesh axis for dimension -2 of input and output.
        """
        name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}'
        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0],
                -2: [mesh_dim_1]
            },
            "other": {
                0: [mesh_dim_0]
            },
            "bias": {},
            "output": {
                # mesh axes are given as a list, like every other entry in these mappings
                0: [mesh_dim_0],
                -2: [mesh_dim_1]
            }
        }
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

        # get communication actions
        other_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping['other'],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim_1)
        communication_action_mapping = {'other': other_comm_spec}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1):
        """
        Build the `SbSj = SbR x SbSj` strategy.

        The batch dimension of input, other and output is sharded along `mesh_dim_0`;
        the last dimension of other, bias and output is sharded along `mesh_dim_1`.
        The input gets an all-reduce over `mesh_dim_1` in the backward pass.

        Args:
            mesh_dim_0: device mesh axis for the batch dimension.
            mesh_dim_1: device mesh axis for dimension -1 of other, bias and output.
        """
        name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}'
        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0]
            },
            "other": {
                0: [mesh_dim_0],
                -1: [mesh_dim_1]
            },
            "bias": {
                -1: [mesh_dim_1]
            },
            "output": {
                0: [mesh_dim_0],
                -1: [mesh_dim_1]
            }
        }
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

        # get communication actions
        input_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping['input'],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim_1)
        communication_action_mapping = {'input': input_comm_spec}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1):
        """
        Build the `SbR = SbSk x SbSk` strategy.

        The batch dimension of every tensor is sharded along `mesh_dim_0`, and the
        contracted dimension (last of input, second-to-last of other) along `mesh_dim_1`.
        The output is all-reduced over `mesh_dim_1` in the forward pass.

        Args:
            mesh_dim_0: device mesh axis for the batch dimension.
            mesh_dim_1: device mesh axis for the contracted dimension.
        """
        name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}'
        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0],
                -1: [mesh_dim_1]
            },
            "other": {
                0: [mesh_dim_0],
                -2: [mesh_dim_1]
            },
            "bias": {},
            "output": {
                # only the batch dimension stays sharded: the contracted axis is
                # all-reduced away, matching the `Sb R` output in the strategy name
                0: [mesh_dim_0]
            }
        }
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

        # get communication actions
        output_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping['output'],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=mesh_dim_1)
        communication_action_mapping = {'output': output_comm_spec}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def generate(self) -> List[ShardingStrategy_V2]:
        """
        Collect all batched matmul strategies.

        The batch-only strategy is included only when `split_one_batch_dim`
        returns one; the remaining strategies are generated for both orderings
        of mesh axes 0 and 1.
        """
        strategy_list = []

        # split only the batch dimension
        # Sb = Sb x Sb
        # can be None as it is only for 1D device mesh
        strategy = self.split_one_batch_dim()
        if strategy:
            strategy_list.append(strategy)

        # split batch dim of two inputs and the i dim of the first tensor
        # SbSi = SbSi x Sb
        strategy_list.append(self.split_batch_dim_lhs_space(0, 1))
        strategy_list.append(self.split_batch_dim_lhs_space(1, 0))

        # split batch dim of two inputs and the j of the second tensor
        # SbSj = Sb x SbSj
        strategy_list.append(self.split_batch_dim_rhs_space(0, 1))
        strategy_list.append(self.split_batch_dim_rhs_space(1, 0))

        # split batch dim of two inputs and the k dim of two inputs
        # Sb = SbSk x SbSk, need to all-reduce by k dim
        strategy_list.append(self.split_batch_dim_both_contract(0, 1))
        strategy_list.append(self.split_batch_dim_both_contract(1, 0))

        # split two batch dim
        strategy_list.append(self.split_two_batch_dim(0, 1))
        strategy_list.append(self.split_two_batch_dim(1, 0))

        return strategy_list
|
||||||
|
|
|
@ -7,7 +7,7 @@ from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
|
||||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||||
from colossalai.device.device_mesh import DeviceMesh
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
from typing import Dict, List, Union, Any
|
from typing import Dict, List, Union, Any
|
||||||
from ..sharding_strategy import OperationData, ShardingStrategy_V2, TrainCycleItem
|
from ..sharding_strategy import OperationData, ShardingStrategy_V2, TrainCycleItem, OperationDataType
|
||||||
|
|
||||||
|
|
||||||
class StrategyGenerator_V2(ABC):
|
class StrategyGenerator_V2(ABC):
|
||||||
|
@ -21,6 +21,10 @@ class StrategyGenerator_V2(ABC):
|
||||||
self.op_data = operation_data_mapping
|
self.op_data = operation_data_mapping
|
||||||
self.device_mesh = device_mesh
|
self.device_mesh = device_mesh
|
||||||
|
|
||||||
|
def is_param(self, op_data_name):
    """Return True when the operand registered under `op_data_name` has the PARAM type."""
    return self.op_data[op_data_name].type == OperationDataType.PARAM
|
||||||
|
|
||||||
def get_sharding_strategy(self, name: str, sharding_spec_mapping: Dict[str, ShardingSpec],
|
def get_sharding_strategy(self, name: str, sharding_spec_mapping: Dict[str, ShardingSpec],
|
||||||
communication_action_mapping: Dict[str, CommSpec]):
|
communication_action_mapping: Dict[str, CommSpec]):
|
||||||
"""
|
"""
|
||||||
|
@ -80,7 +84,7 @@ class StrategyGenerator_V2(ABC):
|
||||||
Compute the communication cost involved in the forward and backward iteration.
|
Compute the communication cost involved in the forward and backward iteration.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
comm_cost = TrainCycleItem(fwd=0, bwd=0)
|
comm_cost = TrainCycleItem(fwd=0, bwd=0, total=0)
|
||||||
|
|
||||||
def _compute_and_add(data: OperationData, comm_spec: CommSpec):
|
def _compute_and_add(data: OperationData, comm_spec: CommSpec):
|
||||||
num_ele_in_comm = comm_spec.get_comm_cost()
|
num_ele_in_comm = comm_spec.get_comm_cost()
|
||||||
|
@ -92,7 +96,7 @@ class StrategyGenerator_V2(ABC):
|
||||||
# TODO: comm_spec.get_comm_cost should return a TrainCycleItem instead of the total cost.
|
# TODO: comm_spec.get_comm_cost should return a TrainCycleItem instead of the total cost.
|
||||||
# it works fine here because only REDUCE_FWD_IDENTITY_BWD and IDENTITY_FWD_ALLREDUCE_BWD are used,
|
# it works fine here because only ALLREDUCE_FWD_IDENTITY_BWD and IDENTITY_FWD_ALLREDUCE_BWD are used,
|
||||||
# so total cost is either for fwd or bwd.
|
# so total cost is either for fwd or bwd.
|
||||||
if comm_spec.comm_pattern == CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD:
|
if comm_spec.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD:
|
||||||
comm_cost.fwd += cost
|
comm_cost.fwd += cost
|
||||||
elif comm_spec.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD:
|
elif comm_spec.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD:
|
||||||
comm_cost.fwd += cost
|
comm_cost.fwd += cost
|
||||||
|
@ -102,9 +106,12 @@ class StrategyGenerator_V2(ABC):
|
||||||
# check if communication action exists
|
# check if communication action exists
|
||||||
# if so, loop over each action and compute the cost of each action
|
# if so, loop over each action and compute the cost of each action
|
||||||
if strategy.communication_actions is not None:
|
if strategy.communication_actions is not None:
|
||||||
for operand, comm_spec in strategy.communication_actions:
|
for operand, comm_spec in strategy.communication_actions.items():
|
||||||
_compute_and_add(operand, comm_spec)
|
_compute_and_add(operand, comm_spec)
|
||||||
|
|
||||||
|
# update the total cost
|
||||||
|
comm_cost.total = comm_cost.fwd + comm_cost.bwd
|
||||||
|
|
||||||
# update the communication cost attribute in-place
|
# update the communication cost attribute in-place
|
||||||
strategy.communication_cost = comm_cost
|
strategy.communication_cost = comm_cost
|
||||||
return strategy
|
return strategy
|
||||||
|
@ -146,7 +153,7 @@ class StrategyGenerator_V2(ABC):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def validate(self, *args, **kwargs) -> bool:
|
def validate(self) -> bool:
|
||||||
"""
|
"""
|
||||||
Validate if the operands are of desired shape.
|
Validate if the operands are of desired shape.
|
||||||
If True, means this generator can be used for the current operation.
|
If True, means this generator can be used for the current operation.
|
||||||
|
|
|
@ -8,9 +8,9 @@ from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
|
||||||
|
|
||||||
def test_linear_module_handler():
|
def test_linear_module_handler():
|
||||||
model = nn.Sequential(nn.Linear(10, 20).to('meta'))
|
model = nn.Sequential(nn.Linear(16, 32).to('meta'))
|
||||||
tracer = ColoTracer()
|
tracer = ColoTracer()
|
||||||
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')})
|
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
|
||||||
gm = ColoGraphModule(model, graph)
|
gm = ColoGraphModule(model, graph)
|
||||||
physical_mesh_id = torch.arange(0, 4)
|
physical_mesh_id = torch.arange(0, 4)
|
||||||
|
|
||||||
|
@ -34,32 +34,55 @@ def test_linear_module_handler():
|
||||||
|
|
||||||
assert mapping['input'].name == "input_1"
|
assert mapping['input'].name == "input_1"
|
||||||
assert mapping['input'].data.is_meta
|
assert mapping['input'].data.is_meta
|
||||||
assert mapping['input'].data.shape == torch.Size([4, 10])
|
assert mapping['input'].data.shape == torch.Size([4, 16])
|
||||||
assert mapping['input'].type == OperationDataType.ARG
|
assert mapping['input'].type == OperationDataType.ARG
|
||||||
assert mapping['input'].logical_shape == torch.Size([4, 10])
|
assert mapping['input'].logical_shape == torch.Size([4, 16])
|
||||||
|
|
||||||
assert mapping['other'].name == "weight"
|
assert mapping['other'].name == "weight"
|
||||||
assert mapping['other'].data.is_meta
|
assert mapping['other'].data.is_meta
|
||||||
assert mapping['other'].data.shape == torch.Size([20, 10])
|
assert mapping['other'].data.shape == torch.Size([32, 16])
|
||||||
assert mapping['other'].type == OperationDataType.PARAM
|
assert mapping['other'].type == OperationDataType.PARAM
|
||||||
assert mapping['other'].logical_shape == torch.Size([10, 20])
|
assert mapping['other'].logical_shape == torch.Size([16, 32])
|
||||||
|
|
||||||
assert mapping['bias'].name == "bias"
|
assert mapping['bias'].name == "bias"
|
||||||
assert mapping['bias'].data.is_meta
|
assert mapping['bias'].data.is_meta
|
||||||
assert mapping['bias'].data.shape == torch.Size([20])
|
assert mapping['bias'].data.shape == torch.Size([32])
|
||||||
assert mapping['bias'].type == OperationDataType.PARAM
|
assert mapping['bias'].type == OperationDataType.PARAM
|
||||||
assert mapping['other'].logical_shape == torch.Size([10, 20])
|
assert mapping['other'].logical_shape == torch.Size([16, 32])
|
||||||
|
|
||||||
assert mapping['output'].name == "_0"
|
assert mapping['output'].name == "_0"
|
||||||
assert mapping['output'].data.is_meta
|
assert mapping['output'].data.is_meta
|
||||||
assert mapping['output'].data.shape == torch.Size([4, 20])
|
assert mapping['output'].data.shape == torch.Size([4, 32])
|
||||||
assert mapping['output'].type == OperationDataType.OUTPUT
|
assert mapping['output'].type == OperationDataType.OUTPUT
|
||||||
|
|
||||||
|
strategies_vector = handler.register_strategy()
|
||||||
|
strategy_name_list = [val.name for val in strategies_vector]
|
||||||
|
|
||||||
|
# SS = SR x RS
|
||||||
|
assert 'S0S1 = S0R x RS1' in strategy_name_list
|
||||||
|
assert 'S1S0 = S1R x RS0' in strategy_name_list
|
||||||
|
|
||||||
|
# SR = SS x SR
|
||||||
|
assert 'S0R = S0S1 x S1R' in strategy_name_list
|
||||||
|
assert 'S1R = S1S0 x S0R' in strategy_name_list
|
||||||
|
|
||||||
|
# RS = RS x SS
|
||||||
|
assert 'RS0 = RS1 x S1S0' in strategy_name_list
|
||||||
|
assert 'RS1 = RS0 x S0S1' in strategy_name_list
|
||||||
|
|
||||||
|
# RR = RS x SR
|
||||||
|
assert 'RR = RS0 x S0R' in strategy_name_list
|
||||||
|
assert 'RR = RS1 x S1R' in strategy_name_list
|
||||||
|
|
||||||
|
# RS= RR x RS
|
||||||
|
assert 'RS0 = RR x RS0' in strategy_name_list
|
||||||
|
assert 'RS1 = RR x RS1' in strategy_name_list
|
||||||
|
|
||||||
|
|
||||||
def test_linear_function_handler():
|
def test_linear_function_handler():
|
||||||
model = nn.Linear(10, 20).to('meta')
|
model = nn.Linear(16, 32).to('meta')
|
||||||
tracer = ColoTracer()
|
tracer = ColoTracer()
|
||||||
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')})
|
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
|
||||||
gm = ColoGraphModule(model, graph)
|
gm = ColoGraphModule(model, graph)
|
||||||
physical_mesh_id = torch.arange(0, 4)
|
physical_mesh_id = torch.arange(0, 4)
|
||||||
|
|
||||||
|
@ -77,27 +100,50 @@ def test_linear_function_handler():
|
||||||
|
|
||||||
assert mapping['input'].name == "input_1"
|
assert mapping['input'].name == "input_1"
|
||||||
assert mapping['input'].data.is_meta
|
assert mapping['input'].data.is_meta
|
||||||
assert mapping['input'].data.shape == torch.Size([4, 10])
|
assert mapping['input'].data.shape == torch.Size([4, 16])
|
||||||
assert mapping['input'].type == OperationDataType.ARG
|
assert mapping['input'].type == OperationDataType.ARG
|
||||||
assert mapping['input'].logical_shape == torch.Size([4, 10])
|
assert mapping['input'].logical_shape == torch.Size([4, 16])
|
||||||
|
|
||||||
assert mapping['other'].name == "weight"
|
assert mapping['other'].name == "weight"
|
||||||
assert mapping['other'].data.is_meta
|
assert mapping['other'].data.is_meta
|
||||||
assert mapping['other'].data.shape == torch.Size([20, 10])
|
assert mapping['other'].data.shape == torch.Size([32, 16])
|
||||||
assert mapping['other'].type == OperationDataType.PARAM
|
assert mapping['other'].type == OperationDataType.PARAM
|
||||||
assert mapping['other'].logical_shape == torch.Size([10, 20])
|
assert mapping['other'].logical_shape == torch.Size([16, 32])
|
||||||
|
|
||||||
assert mapping['bias'].name == "bias"
|
assert mapping['bias'].name == "bias"
|
||||||
assert mapping['bias'].data.is_meta
|
assert mapping['bias'].data.is_meta
|
||||||
assert mapping['bias'].data.shape == torch.Size([20])
|
assert mapping['bias'].data.shape == torch.Size([32])
|
||||||
assert mapping['bias'].type == OperationDataType.PARAM
|
assert mapping['bias'].type == OperationDataType.PARAM
|
||||||
assert mapping['other'].logical_shape == torch.Size([10, 20])
|
assert mapping['other'].logical_shape == torch.Size([16, 32])
|
||||||
|
|
||||||
assert mapping['output'].name == "linear"
|
assert mapping['output'].name == "linear"
|
||||||
assert mapping['output'].data.is_meta
|
assert mapping['output'].data.is_meta
|
||||||
assert mapping['output'].data.shape == torch.Size([4, 20])
|
assert mapping['output'].data.shape == torch.Size([4, 32])
|
||||||
assert mapping['output'].type == OperationDataType.OUTPUT
|
assert mapping['output'].type == OperationDataType.OUTPUT
|
||||||
|
|
||||||
|
strategies_vector = handler.register_strategy()
|
||||||
|
strategy_name_list = [val.name for val in strategies_vector]
|
||||||
|
|
||||||
|
# SS = SR x RS
|
||||||
|
assert 'S0S1 = S0R x RS1' in strategy_name_list
|
||||||
|
assert 'S1S0 = S1R x RS0' in strategy_name_list
|
||||||
|
|
||||||
|
# SR = SS x SR
|
||||||
|
assert 'S0R = S0S1 x S1R' in strategy_name_list
|
||||||
|
assert 'S1R = S1S0 x S0R' in strategy_name_list
|
||||||
|
|
||||||
|
# RS = RS x SS
|
||||||
|
assert 'RS0 = RS1 x S1S0' in strategy_name_list
|
||||||
|
assert 'RS1 = RS0 x S0S1' in strategy_name_list
|
||||||
|
|
||||||
|
# RR = RS x SR
|
||||||
|
assert 'RR = RS0 x S0R' in strategy_name_list
|
||||||
|
assert 'RR = RS1 x S1R' in strategy_name_list
|
||||||
|
|
||||||
|
# RS= RR x RS
|
||||||
|
assert 'RS0 = RR x RS0' in strategy_name_list
|
||||||
|
assert 'RS1 = RR x RS1' in strategy_name_list
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_linear_module_handler()
|
test_linear_module_handler()
|
|
@ -1,6 +1,3 @@
|
||||||
from curses import meta
|
|
||||||
from math import dist
|
|
||||||
from xml.dom import HierarchyRequestErr
|
|
||||||
from colossalai.fx.tracer import meta_patch
|
from colossalai.fx.tracer import meta_patch
|
||||||
from colossalai.fx.tracer.tracer import ColoTracer
|
from colossalai.fx.tracer.tracer import ColoTracer
|
||||||
from colossalai.fx.tracer.meta_patch.patched_function import python_ops
|
from colossalai.fx.tracer.meta_patch.patched_function import python_ops
|
||||||
|
|
|
@ -1,6 +1,3 @@
|
||||||
from curses import meta
|
|
||||||
from math import dist
|
|
||||||
from xml.dom import HierarchyRequestErr
|
|
||||||
from colossalai.fx.tracer import meta_patch
|
from colossalai.fx.tracer import meta_patch
|
||||||
from colossalai.fx.tracer.tracer import ColoTracer
|
from colossalai.fx.tracer.tracer import ColoTracer
|
||||||
from colossalai.fx.tracer.meta_patch.patched_function import python_ops
|
from colossalai.fx.tracer.meta_patch.patched_function import python_ops
|
||||||
|
|
Loading…
Reference in New Issue