mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] added strategy generator and bmm strategies (#1602)
parent
a19eb80998
commit
db98b695b2
|
@ -1,15 +1,168 @@
|
|||
import operator
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from .operator_handler import OperatorHandler
|
||||
from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP
|
||||
from functools import reduce
|
||||
from enum import Enum
|
||||
from .strategy_generator import StrategyGenerator, IntermediateStrategy
|
||||
from typing import List
|
||||
|
||||
__all__ = ['DotHandler']
|
||||
|
||||
|
||||
class MatMulStrategyGenerator(StrategyGenerator):
|
||||
# TODO: to be implmented
|
||||
pass
|
||||
|
||||
|
||||
class BatchedMatMulStrategyGenerator(StrategyGenerator):
|
||||
"""
|
||||
Generate sharding strategies for the batched matrix multiplication.
|
||||
|
||||
A batched matrix multiplication can be viewed as
|
||||
[b, i, k] x [b, k, j] -> [b, i, j]
|
||||
"""
|
||||
|
||||
def __init__(self, is_torch_bmm: bool, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.is_torch_bmm = is_torch_bmm
|
||||
|
||||
def split_one_batch_dim(self):
|
||||
if 1 in self.device_mesh.mesh_shape:
|
||||
mesh_dim = self.device_mesh.mesh_shape.index(1)
|
||||
name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
|
||||
dim_partition_dict = {
|
||||
"input": {
|
||||
0: [mesh_dim]
|
||||
},
|
||||
"other": {
|
||||
0: [mesh_dim]
|
||||
},
|
||||
"bias": {},
|
||||
"output": {
|
||||
0: [mesh_dim]
|
||||
}
|
||||
}
|
||||
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
|
||||
else:
|
||||
return None
|
||||
|
||||
def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}'
|
||||
dim_partition_dict = {
|
||||
"input": {
|
||||
0: [mesh_dim_0, mesh_dim_1]
|
||||
},
|
||||
"other": {
|
||||
0: [mesh_dim_0, mesh_dim_1]
|
||||
},
|
||||
"bias": {},
|
||||
"output": {
|
||||
0: [mesh_dim_0, mesh_dim_1]
|
||||
}
|
||||
}
|
||||
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
|
||||
|
||||
def split_one_batch_dim(self, mesh_dim):
|
||||
name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
|
||||
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "bias": {}, "output": {0: [mesh_dim]}}
|
||||
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
|
||||
|
||||
def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}'
|
||||
dim_partition_dict = {
|
||||
"input": {
|
||||
0: [mesh_dim_0],
|
||||
-2: [mesh_dim_1]
|
||||
},
|
||||
"other": {
|
||||
0: [mesh_dim_0]
|
||||
},
|
||||
"bias": {},
|
||||
"output": {
|
||||
0: mesh_dim_0,
|
||||
-2: [mesh_dim_1]
|
||||
}
|
||||
}
|
||||
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
|
||||
|
||||
def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}'
|
||||
dim_partition_dict = {
|
||||
"input": {
|
||||
0: [mesh_dim_0]
|
||||
},
|
||||
"other": {
|
||||
0: [mesh_dim_0],
|
||||
-1: [mesh_dim_1]
|
||||
},
|
||||
"bias": {
|
||||
-1: [mesh_dim_1]
|
||||
},
|
||||
"output": {
|
||||
0: [mesh_dim_0],
|
||||
-1: [mesh_dim_1]
|
||||
}
|
||||
}
|
||||
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
|
||||
|
||||
def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}'
|
||||
dim_partition_dict = {
|
||||
"input": {
|
||||
0: [mesh_dim_0],
|
||||
-1: [mesh_dim_1]
|
||||
},
|
||||
"other": {
|
||||
0: [mesh_dim_0],
|
||||
-2: [mesh_dim_1]
|
||||
},
|
||||
"bias": {},
|
||||
"output": {
|
||||
0: [mesh_dim_0],
|
||||
-2: [mesh_dim_1]
|
||||
}
|
||||
}
|
||||
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim_1])
|
||||
|
||||
def generate(self) -> List[IntermediateStrategy]:
|
||||
strategy_list = []
|
||||
|
||||
# split only the batch dimension
|
||||
# Sb = Sb x Sb
|
||||
# can be None as it is only for 1D device mesh
|
||||
strategy = self.split_one_batch_dim()
|
||||
if strategy:
|
||||
strategy_list.append(strategy)
|
||||
|
||||
# split batch dim of two inputs and the i dim of the first tensor
|
||||
# SbSi = SbSi x Sb
|
||||
strategy_list.append(self.split_batch_dim_lhs_space(0, 1))
|
||||
strategy_list.append(self.split_batch_dim_lhs_space(1, 0))
|
||||
|
||||
# split batch dim of two inputs and the j of the second tensor
|
||||
# SbSj = Sb x SbSj
|
||||
strategy_list.append(self.split_batch_dim_rhs_space(0, 1))
|
||||
strategy_list.append(self.split_batch_dim_rhs_space(1, 0))
|
||||
|
||||
# split batch dim of two inputs and the k dim of two inputs
|
||||
# Sb = SbSk x SbSk, need to all-reduce by k dim
|
||||
strategy_list.append(self.split_batch_dim_both_contract(0, 1))
|
||||
strategy_list.append(self.split_batch_dim_both_contract(1, 0))
|
||||
|
||||
# split two batch dim
|
||||
strategy_list.append(self.split_two_batch_dim(0, 1))
|
||||
strategy_list.append(self.split_two_batch_dim(1, 0))
|
||||
|
||||
return strategy_list
|
||||
|
||||
|
||||
class DotHandler(OperatorHandler):
|
||||
"""
|
||||
A OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
|
||||
A OperatorHandler which deals with the sharding strategies for nn.Linear and F.linear.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
@ -297,7 +450,7 @@ class DotHandler(OperatorHandler):
|
|||
|
||||
def register_strategy(self) -> StrategiesVector:
|
||||
'''
|
||||
Generate every possible strategies for a Conv node, and record all strategies into the strategies_vector.
|
||||
Generate every possible strategies for a linear node, and record all strategies into the strategies_vector.
|
||||
|
||||
Output:
|
||||
|
||||
|
|
|
@ -5,7 +5,6 @@ from abc import ABC, abstractmethod
|
|||
from torch.fx.node import Node
|
||||
from typing import Dict, List
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from .._utils import generate_resharding_costs, generate_sharding_spec
|
||||
from colossalai.auto_parallel.solver.constants import *
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
|
||||
__all__ = ['IntermediateStrategy', 'StrategyGenerator']
|
||||
|
||||
|
||||
@dataclass
|
||||
class IntermediateStrategy:
|
||||
"""
|
||||
IntermediateStrategy contains the subset of meta information for ShardingStrategy. It is
|
||||
to store the essential information regarding the tensor sharding and leave other meta information to OperatorHandler.
|
||||
|
||||
Args:
|
||||
name (str): name of the sharding strategy.
|
||||
dim_partition_dict (Dict[Dict]): stores the tensor to dim partition dict mapping.
|
||||
all_reduce_dims (List[int]): stores the dimensions which require an all-reduce operation.
|
||||
"""
|
||||
name: str
|
||||
dim_partition_dict: Dict[str, Dict[int, List[int]]]
|
||||
all_reduce_axis: List[int] = None
|
||||
|
||||
|
||||
class StrategyGenerator(ABC):
|
||||
"""
|
||||
StrategyGenerator is used to generate the same group of sharding strategies.
|
||||
"""
|
||||
|
||||
def __init__(self, device_mesh: DeviceMesh):
|
||||
self.device_mesh = device_mesh
|
||||
|
||||
@abstractmethod
|
||||
def generate(self) -> List[IntermediateStrategy]:
|
||||
pass
|
Loading…
Reference in New Issue