mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
527 lines
20 KiB
527 lines
20 KiB
2 years ago
|
from typing import List, Any, Dict, Tuple
|
||
|
import torch
|
||
|
from torch.fx import Graph, Node
|
||
|
|
||
|
from .solver import SolverFactory
|
||
|
from .training_simulator import TrainingSimulator
|
||
|
from .region import Region
|
||
|
from .util import NodeInfo
|
||
|
|
||
|
|
||
|
class RegionManager:
    """
    RegionManager is used to construct and manage the offload plan for the model execution.

    Args:
        graph (Graph): a Graph object used for analysis and strategy generation.
        solver_name (str): a solver name which specifies the preferences for plan searching.
        memory_budget (float): the given memory budget.
        cnode (List[str], optional): Common node List, should be the subset of input
            (placeholder) names. The list is copied, so the caller's list is never mutated.
    """

    def __init__(self,
                 graph: Graph,
                 solver_name: str = 'asyn',
                 memory_budget: float = -1.0,
                 cnode: List[str] = None):

        self.graph = graph
        assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
        self.root_module = self.graph.owning_module
        self.nodes = list(graph.nodes)
        # copied because ``_linearize_graph`` appends derived common-node names to it
        self.cnode = list(cnode) if cnode is not None else None
        # names of nodes that only depend on parameter-only ops (or are get_attr nodes)
        self.only_param_ops = []
        self.param_region_map: Dict[torch.nn.Parameter, Region] = dict()
        self.shared_region_pairs: List[Tuple[Region, Region]] = list()
        self.region_list: List[Region] = list()
        self.rid_in_pool: List[int] = list()
        self.mem_block_size: int = 0
        self.memory_budget = memory_budget

        self.solver_name = solver_name
        # only the 'asyn' solver places held regions into a shared memory pool
        self.require_pool: bool = solver_name == 'asyn'

        # region id -> index of its block in ``self.memory_pool``
        self.reg_to_block: Dict[int, int] = dict()

    def _build_regions(self):
        """
        1. Pre-processing, mainly contains linearized computing graph and
            merge smaller regions into larger ones.
        2. Construct a solver to search for an efficient offload strategy.
        3. Post-processing, mainly contains early region placement if using asynchronous mode,
            and initialize region data.
        """

        self._pre_process()

        solver_cls = SolverFactory.create(self.solver_name)
        solver = solver_cls(self.region_list, self.memory_budget)
        solver._call_solver()

        self._post_process(solver.best_ts)

    def _pre_process(self):
        """
        Linearize the graph into regions, merge the regions eligible for the memory pool,
        renumber the regions that follow them, and reserve budget for the fp32 staging buffer.

        Raises:
            NotImplementedError: if more than one pair of regions shares parameters.
        """

        init_region_list = self._linearize_graph()

        if len(self.shared_region_pairs) > 1:
            raise NotImplementedError(
                'The current version only considers at most one pair of parameter sharing.')

        elif len(self.shared_region_pairs) == 1:
            shared_regs = self.shared_region_pairs[0]
            assert shared_regs[0].shared_rid == shared_regs[1].r_id \
                and shared_regs[1].shared_rid == shared_regs[0].r_id
            fst_id = shared_regs[0].r_id
            lst_id = shared_regs[1].r_id
            # regions up to and including the first shared region, and from the second
            # shared region on, are kept out of merging and out of the memory pool
            regs_left_out = init_region_list[:fst_id + 1]
            regs_right_out = init_region_list[lst_id:]
            hold_regs = init_region_list[fst_id + 1:lst_id]
        else:
            regs_left_out = []
            regs_right_out = []
            hold_regs = init_region_list

        self.mem_block_size = self._search_block_size(hold_regs)
        hold_regs = self._merge_small_regions(hold_regs)

        if self.require_pool:
            for reg in hold_regs:
                reg.in_mem_pool_flag = True
                self.rid_in_pool.append(reg.r_id)

        self.region_list.extend(regs_left_out)
        self.region_list.extend(hold_regs)

        # merging may have reduced the number of held regions, so renumber the right-hand ones
        for reg in regs_right_out:
            old_rid = reg.r_id
            reg.r_id = self.region_list[-1].r_id + 1
            if reg.shared_rid != old_rid:
                # shared region: its partner sits in ``regs_left_out`` and keeps its id
                self.region_list[reg.shared_rid].shared_rid = reg.r_id
            else:
                # not shared: keep the ``shared_rid == r_id`` convention under the new id
                reg.shared_rid = reg.r_id
            self.region_list.append(reg)

        self._process_shared_region()

        self.max_param_num = max((reg.param_num for reg in self.region_list), default=0)
        # reserve room for the fp32 staging buffer of ``max_param_num`` elements
        # that ``_init_region_data`` allocates
        self.memory_budget -= self.max_param_num * torch.tensor([], dtype=torch.float32).element_size()

    def _post_process(self, ts: 'TrainingSimulator' = None):
        """
        Place pooled regions (only when a memory pool is required) and initialize region data.

        Args:
            ts (TrainingSimulator): the best training simulator found by the solver.
        """
        if self.require_pool:
            self._early_region_placement(ts)
        self._init_region_data()

    def _early_region_placement(self, ts: 'TrainingSimulator'):
        """
        Implemented the early region placement strategy to avoid GPU memory fragmentation.
        It maps all region data into a contiguous memory space and
        reuses the same memory space for regions that do not coexist.

        The number of blocks is the largest number of pooled regions appearing in one row of
        the concatenated forward/backward region flow. Each pooled region is assigned the first
        block that holds no region it coexists with. The pool itself is a zero fp16 CUDA tensor
        of ``mem_block_num * mem_block_size`` bytes split into equal chunks. If no region is
        pooled, ``self.memory_pool`` is an empty tuple.

        Args:
            ts (TrainingSimulator): the best training simulator, which records region execution flow.

        Raises:
            NotImplementedError: due to the naive implementation,
                it may not find a suitable region placement strategy for the given execution flow.
        """

        if not self.rid_in_pool:
            # nothing to place; torch.chunk would reject chunks=0
            self.memory_pool = ()
            return

        pooled_rids = set(self.rid_in_pool)
        reg_flow = torch.cat([ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0)
        mem_block_num = int(torch.max(torch.sum(reg_flow[:, self.rid_in_pool], dim=1)))
        coexist_matrix = torch.logical_or(ts.fwd_reg_flow, ts.bwd_reg_flow)

        block_to_regs = {block_idx: [] for block_idx in range(mem_block_num)}
        for reg in self.region_list:
            if reg.r_id not in pooled_rids:
                continue
            # rows (steps) where this region appears, and every region present in any of them
            cur_reg_appears = coexist_matrix[:, reg.r_id]
            cur_reg_coexists = torch.sum(coexist_matrix[cur_reg_appears], dim=0).bool()
            for block_idx in range(mem_block_num):
                if not any(cur_reg_coexists[block_to_regs[block_idx]]):
                    block_to_regs[block_idx].append(reg.r_id)
                    self.reg_to_block[reg.r_id] = block_idx
                    break

            if reg.r_id not in self.reg_to_block:
                raise NotImplementedError(
                    f'can not find a block from the memory pool to store parameters of the region {reg.r_id}')

        # mem_block_size is in bytes; an fp16 element takes 2 bytes
        self.memory_pool = torch.chunk(
            torch.zeros(mem_block_num * self.mem_block_size // 2, dtype=torch.half, device='cuda'),
            chunks=mem_block_num)

    def _merge_small_regions(self, orig_reg_list: List['Region']) -> List['Region']:
        """
        Merge smaller regions into larger ones for better bandwidth utilization and easier management.
        It is inspired by Gemini.

        Consecutive regions are packed greedily, in order: a new merged region is started
        whenever adding the next region would push ``param_size`` above ``self.mem_block_size``.
        Merged ids start at the id of the first input region.

        Args:
            orig_reg_list (List[Region]): original region list.

        Returns:
            List[Region]: region list after merging (empty if ``orig_reg_list`` is empty).
        """

        if not orig_reg_list:
            return []

        r_id = orig_reg_list[0].r_id
        region = Region(r_id=r_id)
        region_list = [region]

        for orig_reg in orig_reg_list:
            if region.param_size + orig_reg.param_size > self.mem_block_size:
                r_id += 1
                region = Region(r_id=r_id)
                region_list.append(region)
            region.param_size += orig_reg.param_size
            region.param_num += orig_reg.param_num
            region.nodes.extend(orig_reg.nodes)
            region.fp16_params.extend(orig_reg.fp16_params)
            self.__update_param_region_map(orig_reg.fp16_params, region)

        return region_list

    def _search_block_size(self,
                           region_list: List['Region'],
                           search_interval_byte: int = 1024,
                           search_range_byte: int = 128 * 1024**2) -> int:
        """
        Search for a suitable memory block size.

        Candidate sizes go from the largest non-shared region size up to that size plus
        ``search_range_byte``, in steps of ``search_interval_byte``. The candidate wasting the
        fewest bytes wins, where waste is measured by packing the region sizes in order, exactly
        as ``_merge_small_regions`` does.

        Args:
            region_list (List[Region]): region list.
            search_interval_byte (int): searching interval in byte.
            search_range_byte (int): searching range in byte.

        Returns:
            int: the best memory block size (0 if no non-shared region is given).
        """

        def _get_wasted_mem(size_list: List[int], blk_size: int) -> int:
            """
            Get wasted byte for a certain block size when packing ``size_list`` in order.
            """
            acc_wasted = 0
            used = 0
            for s in size_list:
                if used + s > blk_size:
                    # close the current block; the new block starts holding only ``s``
                    acc_wasted += blk_size - used
                    used = 0
                used += s
            acc_wasted += blk_size - used
            return acc_wasted

        param_size_list = [region.param_size for region in region_list if region.r_id == region.shared_rid]
        if not param_size_list:
            return 0

        start_size = max(param_size_list)
        min_mem_waste = float('+inf')
        best_block_size = start_size

        for block_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
            temp_waste = _get_wasted_mem(param_size_list, block_size)
            if temp_waste < min_mem_waste:
                min_mem_waste = temp_waste
                best_block_size = block_size

        return best_block_size

    def _init_region_data(self):
        """
        Initialize region data, which maps the parameters in the region to a contiguous memory space.

        For pooled regions, the assigned memory-pool block is passed to ``Region.init_param_data``.
        A region whose ``r_id`` is greater than its ``shared_rid`` reuses its partner's data
        instead of initializing its own.
        """

        self.temp_fp32_data = torch.zeros(self.max_param_num, device='cuda', dtype=torch.float32)

        for region in self.region_list:
            pre_alloc_tensor = None
            if self.require_pool and region.r_id in self.rid_in_pool:
                block_idx = self.reg_to_block[region.r_id]
                pre_alloc_tensor = self.memory_pool[block_idx]

            if region.r_id <= region.shared_rid:
                region.init_param_data(pre_alloc_tensor)
            else:
                shared_region = self.region_list[region.shared_rid]
                region.fp16_data = shared_region.fp16_data
                region.fp32_data = shared_region.fp32_data
                region.param_to_range = shared_region.param_to_range
            # every region views a prefix of the single fp32 staging buffer
            region.temp_fp32_data = self.temp_fp32_data[:region.param_num].detach()

        torch.cuda.empty_cache()

    def _process_shared_region(self):
        """
        Special processing for the shared region, which uses GPT2 and Bert case as a priori knowledge.

        The former region of the pair must end with an ``nn.Embedding`` module call. If the latter
        region holds more parameters than the former, it is split via ``Region.split`` right after
        its first linear op. The returned region is inserted into ``region_list`` at its own
        ``r_id``, and the regions after it are renumbered (non-shared ones keep
        ``shared_rid == r_id``).

        Raises:
            NotImplementedError: if the latter region must be split but contains no linear op.
        """

        if not self.shared_region_pairs:
            return

        assert len(self.shared_region_pairs) <= 1
        former_reg, latter_reg = self.shared_region_pairs[0]
        assert latter_reg.param_num >= former_reg.param_num
        embedding_node = former_reg.nodes[-1]
        assert embedding_node.op == 'call_module' and isinstance(
            self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding)

        if latter_reg.param_num > former_reg.param_num:
            cut_node_idx = None
            for idx, n in enumerate(latter_reg.nodes):
                is_linear_module = n.op == 'call_module' and isinstance(
                    self.root_module.get_submodule(n.target), torch.nn.Linear)
                is_linear_func = n.op == 'call_function' and n.target is torch.nn.functional.linear
                if is_linear_module or is_linear_func:
                    cut_node_idx = idx + 1
                    break
            assert len(latter_reg.fp16_params) == 2
            if cut_node_idx is None:
                raise NotImplementedError(
                    'The shared region has extra parameters but contains no linear op to split at.')

            new_reg = latter_reg.split(cut_node_idx, 1)
            for p in new_reg.fp16_params:
                self.param_region_map[p] = new_reg
            self.region_list.insert(new_reg.r_id, new_reg)
            for reg in self.region_list[new_reg.r_id + 1:]:
                is_self_shared = reg.shared_rid == reg.r_id
                reg.r_id += 1
                if is_self_shared:
                    reg.shared_rid = reg.r_id

        latter_reg.shared_rid = former_reg.r_id
        former_reg.shared_rid = latter_reg.r_id

    def _linearize_graph(self) -> List['Region']:
        """Linearize the computing graph into a list of regions.

        Placeholder and output nodes are skipped. A region is closed after a node that ends a
        parameter computation. A region is also split after the last activation node when a
        new parameter computation starts.

        Returns:
            List[Region]: each region contains the actual 'node' in linearized manner.

        Remarks:
            Do merge the inplace ops and shape-consistency ops into the previous node.
        """

        # List of target name that could be seen as common node
        common_ops = ["getattr", "getitem", "size"]

        # pending users of parameter-only nodes; a region may only start/end when all are consumed
        param_op_deps = {}

        def _is_cop(target: Any) -> bool:
            """Check if an op could be seen as common node

            Args:
                target (Any): node target

            Returns:
                bool
            """
            if isinstance(target, str):
                return target in common_ops
            return target.__name__ in common_ops

        def _is_act(data: Any) -> bool:
            """Check if node meta data is (or contains, for tuples/lists) a ``torch.Tensor``

            Args:
                data (Any): meta_data

            Returns:
                bool
            """
            if isinstance(data, torch.Tensor):
                return True
            if isinstance(data, (tuple, list)):
                return any(_is_act(d) for d in data)
            return False

        def _owns_state(submod: torch.nn.Module) -> bool:
            """Check if a module directly owns parameters or buffers."""
            return (len(list(submod.named_parameters(recurse=False))) != 0
                    or len(list(submod.named_buffers(recurse=False))) != 0)

        def _no_pending_param_deps() -> bool:
            return not sum(param_op_deps.values())

        def _maybe_param_comp_start(node: Node) -> bool:
            """Check if a node could be seen as parameter computation start

            Args:
                node (Node): node

            Returns:
                bool
            """
            if node.op == "get_attr":
                label = True
            elif node.op == "call_module":
                label = _owns_state(self.root_module.get_submodule(node.target))
            else:
                label = False
            return label and _no_pending_param_deps()

        def _is_inplace(node: Node) -> bool:
            """Get the inplace argument from ``torch.fx.Node``
            """
            if node.op == "call_function":
                return node.kwargs.get("inplace", False)
            if node.op == "call_module":
                return getattr(node.graph.owning_module.get_submodule(node.target), "inplace", False)
            return False

        def _is_param_comp_end(node: Node) -> bool:
            """Check if a node could be seen as parameter computation end

            Args:
                node (Node): node

            Returns:
                bool
            """
            label = False
            if node.op == "call_module":
                label = _owns_state(self.root_module.get_submodule(node.target))
            elif node.op == "call_function":
                inputs = node.all_input_nodes
                label = any(x.name in self.only_param_ops for x in inputs) and any(
                    x.name not in self.only_param_ops and not _is_cop(node.target) for x in inputs)
            # do not end before an inplace user so it stays in the same region
            return label and _no_pending_param_deps() and not any(map(_is_inplace, node.users))

        def _exception_node_handling(node: Node) -> None:
            # TODO meta info prop bug
            if "transpose" in node.name and node.meta['fwd_out'][0].dim() <= 2:
                node.meta['fwd_out'] = []

        # make sure that item in cnode is valid
        if self.cnode:
            for name in self.cnode:
                try:
                    assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \
                        f"Common node {name} is not an input of the model."
                except StopIteration:
                    raise ValueError(f"Common node name {name} not in graph.")
        else:
            self.cnode = []

        node_id = 0
        region_id = 0
        region_list = []
        region = Region(r_id=region_id)
        act_n = None

        for n in self.graph.nodes:
            if n.op == "placeholder" or n.op == "output":
                continue

            for n_par in n.all_input_nodes:
                if n_par.op != "placeholder" and n_par.name in self.only_param_ops:
                    param_op_deps[n_par] -= 1

            if act_n in region.nodes and _maybe_param_comp_start(n):
                # move the nodes after the last activation into the new region
                border_n_idx = region.nodes.index(act_n)
                ns = region.nodes[border_n_idx + 1:]
                region.nodes = region.nodes[:border_n_idx + 1]
                region_list.append(region)
                region_id += 1
                region = Region(r_id=region_id)
                region.nodes = ns

            _exception_node_handling(n)
            region.nodes.append(n)
            self._set_node_and_region_info(node_id, n, region)
            node_id += 1

            # if the node could free all dependencies in graph
            # we could begin a new region
            if _is_param_comp_end(n):
                region_list.append(region)
                region_id += 1
                region = Region(r_id=region_id)

            inputs = n.all_input_nodes

            # propagate common node attr if possible
            if all(x.name in self.cnode for x in inputs) or _is_cop(n.target):
                self.cnode.append(n.name)

            # propagate param node attr if possible
            if all(x.name in self.only_param_ops for x in inputs) or n.op == "get_attr":
                self.only_param_ops.append(n.name)
                param_op_deps[n] = len([user for user in n.users if user.op != "output"])

            # record last activation node
            if _is_act(n._meta_data):
                act_n = n

        if region.nodes:
            region_list.append(region)

        return region_list

    def _set_node_and_region_info(self, node_id: int, cur_n: Node, cur_reg: 'Region'):
        """
        Attach a ``NodeInfo`` to ``cur_n`` and record in ``cur_reg`` the parameters the node uses
        (direct parameters of a called module, or a parameter fetched by ``get_attr``).
        """

        cur_n.node_info = NodeInfo(node_id)

        if cur_n.op == 'call_module':
            submod = self.root_module.get_submodule(cur_n.target)
            for p in list(submod.parameters(recurse=False)):
                self._add_param_to_region(p, cur_reg)

        elif cur_n.op == "get_attr":
            attr_itr = self.root_module
            for atom in cur_n.target.split("."):
                attr_itr = getattr(attr_itr, atom)

            if isinstance(attr_itr, torch.nn.Parameter):
                self._add_param_to_region(attr_itr, cur_reg)

    def _add_param_to_region(self, param: torch.nn.Parameter, cur_reg: 'Region'):
        """
        Append ``param`` to ``cur_reg``. If another region already owns it, link the two regions
        as a shared pair instead of re-registering the owner.
        """
        if param in self.param_region_map:
            owner = self.param_region_map[param]
            cur_reg.shared_rid = owner.r_id
            owner.shared_rid = cur_reg.r_id
            self.shared_region_pairs.append((owner, cur_reg))
        else:
            self.param_region_map[param] = cur_reg

        cur_reg.fp16_params.append(param)
        cur_reg.param_num += param.data.numel()
        cur_reg.param_size += param.data.numel() * param.data.element_size()

    def get_region(self, param: torch.nn.Parameter) -> 'Region':
        """
        Return the region owning the parameter.

        Args:
            param (torch.nn.Parameter): a torch parameter object

        Raises:
            KeyError: if the parameter does not belong to any region.
        """
        return self.param_region_map[param]

    def __update_param_region_map(self, params: List[torch.nn.Parameter], region: 'Region'):
        """Map every parameter in ``params`` to ``region``."""
        for p in params:
            self.param_region_map[p] = region
|