[autoparallel] added strategy generator and bmm strategies (#1602)

pull/1603/head
Frank Lee 2022-09-15 16:57:07 +08:00 committed by GitHub
parent a19eb80998
commit db98b695b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 190 additions and 3 deletions

View File

@ -1,15 +1,168 @@
import operator
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP
from functools import reduce
from enum import Enum
from .strategy_generator import StrategyGenerator, IntermediateStrategy
from typing import List
__all__ = ['DotHandler']
class MatMulStrategyGenerator(StrategyGenerator):
# TODO: to be implmented
pass
class BatchedMatMulStrategyGenerator(StrategyGenerator):
"""
Generate sharding strategies for the batched matrix multiplication.
A batched matrix multiplication can be viewed as
[b, i, k] x [b, k, j] -> [b, i, j]
"""
def __init__(self, is_torch_bmm: bool, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_torch_bmm = is_torch_bmm
def split_one_batch_dim(self):
if 1 in self.device_mesh.mesh_shape:
mesh_dim = self.device_mesh.mesh_shape.index(1)
name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
dim_partition_dict = {
"input": {
0: [mesh_dim]
},
"other": {
0: [mesh_dim]
},
"bias": {},
"output": {
0: [mesh_dim]
}
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
else:
return None
def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}'
dim_partition_dict = {
"input": {
0: [mesh_dim_0, mesh_dim_1]
},
"other": {
0: [mesh_dim_0, mesh_dim_1]
},
"bias": {},
"output": {
0: [mesh_dim_0, mesh_dim_1]
}
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
def split_one_batch_dim(self, mesh_dim):
name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "bias": {}, "output": {0: [mesh_dim]}}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1):
name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}'
dim_partition_dict = {
"input": {
0: [mesh_dim_0],
-2: [mesh_dim_1]
},
"other": {
0: [mesh_dim_0]
},
"bias": {},
"output": {
0: mesh_dim_0,
-2: [mesh_dim_1]
}
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1):
name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}'
dim_partition_dict = {
"input": {
0: [mesh_dim_0]
},
"other": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
},
"bias": {
-1: [mesh_dim_1]
},
"output": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
}
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1):
name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}'
dim_partition_dict = {
"input": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
},
"other": {
0: [mesh_dim_0],
-2: [mesh_dim_1]
},
"bias": {},
"output": {
0: [mesh_dim_0],
-2: [mesh_dim_1]
}
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim_1])
def generate(self) -> List[IntermediateStrategy]:
strategy_list = []
# split only the batch dimension
# Sb = Sb x Sb
# can be None as it is only for 1D device mesh
strategy = self.split_one_batch_dim()
if strategy:
strategy_list.append(strategy)
# split batch dim of two inputs and the i dim of the first tensor
# SbSi = SbSi x Sb
strategy_list.append(self.split_batch_dim_lhs_space(0, 1))
strategy_list.append(self.split_batch_dim_lhs_space(1, 0))
# split batch dim of two inputs and the j of the second tensor
# SbSj = Sb x SbSj
strategy_list.append(self.split_batch_dim_rhs_space(0, 1))
strategy_list.append(self.split_batch_dim_rhs_space(1, 0))
# split batch dim of two inputs and the k dim of two inputs
# Sb = SbSk x SbSk, need to all-reduce by k dim
strategy_list.append(self.split_batch_dim_both_contract(0, 1))
strategy_list.append(self.split_batch_dim_both_contract(1, 0))
# split two batch dim
strategy_list.append(self.split_two_batch_dim(0, 1))
strategy_list.append(self.split_two_batch_dim(1, 0))
return strategy_list
class DotHandler(OperatorHandler):
"""
A OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
A OperatorHandler which deals with the sharding strategies for nn.Linear and F.linear.
"""
def __init__(self, *args, **kwargs):
@ -297,7 +450,7 @@ class DotHandler(OperatorHandler):
def register_strategy(self) -> StrategiesVector:
'''
Generate every possible strategies for a Conv node, and record all strategies into the strategies_vector.
Generate every possible strategies for a linear node, and record all strategies into the strategies_vector.
Output:

View File

@ -5,7 +5,6 @@ from abc import ABC, abstractmethod
from torch.fx.node import Node
from typing import Dict, List
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from .._utils import generate_resharding_costs, generate_sharding_spec
from colossalai.auto_parallel.solver.constants import *

View File

@ -0,0 +1,35 @@
from dataclasses import dataclass
from abc import ABC, abstractmethod
from typing import List, Dict
from colossalai.device.device_mesh import DeviceMesh
__all__ = ['IntermediateStrategy', 'StrategyGenerator']
@dataclass
class IntermediateStrategy:
"""
IntermediateStrategy contains the subset of meta information for ShardingStrategy. It is
to store the essential information regarding the tensor sharding and leave other meta information to OperatorHandler.
Args:
name (str): name of the sharding strategy.
dim_partition_dict (Dict[Dict]): stores the tensor to dim partition dict mapping.
all_reduce_dims (List[int]): stores the dimensions which require an all-reduce operation.
"""
name: str
dim_partition_dict: Dict[str, Dict[int, List[int]]]
all_reduce_axis: List[int] = None
class StrategyGenerator(ABC):
"""
StrategyGenerator is used to generate the same group of sharding strategies.
"""
def __init__(self, device_mesh: DeviceMesh):
self.device_mesh = device_mesh
@abstractmethod
def generate(self) -> List[IntermediateStrategy]:
pass