# Original file: ColossalAI/colossalai/auto_parallel/offload/region_manager.py (514 lines, 20 KiB)

from typing import Any, Dict, List, Tuple
import torch
from torch.fx import Graph, Node
from .region import Region
from .solver import SolverFactory
from .training_simulator import TrainingSimulator
from .util import NodeInfo
class RegionManager:
    """
    RegionManager is used to construct and manage the offload plan for the model execution.

    Args:
        graph (Graph): a Graph object used for analysis and strategy generation.
        solver_name (str): a solver name which specifies the preferences for plan searching.
            ``"asyn"`` additionally enables the pre-allocated memory pool.
        memory_budget (float): the given memory budget.
        cnode (List[str], optional): Common node List, should be the subset of input.
            The list is copied, so the caller's list is never modified.
    """

    def __init__(self, graph: Graph, solver_name: str = "asyn", memory_budget: float = -1.0, cnode: List[str] = None):
        self.graph = graph
        assert graph.owning_module is not None, "The given graph is not associated with a owning_module"
        self.root_module = self.graph.owning_module
        self.nodes = list(graph.nodes)
        # Copy: ``_linearize_graph`` appends to ``self.cnode`` and must not mutate the caller's list.
        self.cnode = list(cnode) if cnode is not None else None
        # names of nodes computed only from parameters (``get_attr`` nodes and their param-only descendants)
        self.only_param_ops = []
        self.param_region_map: Dict[torch.nn.Parameter, Region] = dict()
        self.shared_region_pairs: List[Tuple[Region, Region]] = list()
        self.region_list: List[Region] = list()
        self.rid_in_pool: List[int] = list()
        self.mem_block_size: int = 0
        self.memory_budget = memory_budget

        self.solver_name = solver_name
        self.require_pool: bool = solver_name == "asyn"

        # region id -> index of the memory-pool block the region is placed in
        self.reg_to_block: Dict[int, int] = dict()

    def _build_regions(self):
        """
        1. Pre-processing, mainly contains linearized computing graph and
            merge smaller regions into larger ones.
        2. Construct a solver to search for an efficient offload strategy.
        3. Post-processing, mainly contains early region placement if using asynchronous mode,
            and initialize region data.
        """
        self._pre_process()

        solver_cls = SolverFactory.create(self.solver_name)
        solver = solver_cls(self.region_list, self.memory_budget)
        solver._call_solver()

        self._post_process(solver.best_ts)

    def _pre_process(self):
        """
        Build ``self.region_list`` from the linearized graph.

        Regions up to and including the first region of the (at most one) parameter-sharing pair,
        and regions from the second one onwards, are kept as they are. The regions in between are
        merged greedily into regions of at most ``self.mem_block_size`` bytes and, when a memory
        pool is required, marked as pool-resident. Finally, the fp32 scratch space needed for the
        largest region (``max_param_num`` float32 values) is subtracted from ``self.memory_budget``.

        Raises:
            NotImplementedError: if more than one pair of regions shares parameters.
        """
        init_region_list = self._linearize_graph()

        if len(self.shared_region_pairs) > 1:
            raise NotImplementedError("The current version only considers at most one pair of parameter sharing.")

        elif len(self.shared_region_pairs) == 1:
            shared_regs = self.shared_region_pairs[0]
            assert shared_regs[0].shared_rid == shared_regs[1].r_id and shared_regs[1].shared_rid == shared_regs[0].r_id
            fst_id = shared_regs[0].r_id
            lst_id = shared_regs[1].r_id
            regs_left_out = init_region_list[: fst_id + 1]
            regs_right_out = init_region_list[lst_id:]
            hold_regs = init_region_list[fst_id + 1 : lst_id]
        else:
            regs_left_out = []
            regs_right_out = []
            hold_regs = init_region_list

        self.mem_block_size = self._search_block_size(hold_regs)
        hold_regs = self._merge_small_regions(hold_regs)

        if self.require_pool:
            for reg in hold_regs:
                reg.in_mem_pool_flag = True
                self.rid_in_pool.append(reg.r_id)

        self.region_list.extend(regs_left_out)
        self.region_list.extend(hold_regs)

        # Merging may have shrunk the middle section, so renumber the trailing regions consecutively.
        for reg in regs_right_out:
            old_rid = reg.r_id
            reg.r_id = self.region_list[-1].r_id + 1
            if reg.shared_rid == old_rid:
                # Unshared region: it refers to itself, so keep that true under the new id.
                reg.shared_rid = reg.r_id
            else:
                # Shared region: its partner sits in ``regs_left_out`` (id unchanged); re-link it.
                self.region_list[reg.shared_rid].shared_rid = reg.r_id
            self.region_list.append(reg)

        self._process_shared_region()

        self.max_param_num = max(reg.param_num for reg in self.region_list)
        self.memory_budget -= self.max_param_num * torch.tensor([], dtype=torch.float32).element_size()

    def _post_process(self, ts: TrainingSimulator = None):
        """
        Place pool-resident regions into memory-pool blocks (only when a pool is required),
        then initialize the data of every region.

        Args:
            ts (TrainingSimulator, optional): the best training simulator found by the solver.
        """
        if self.require_pool:
            self._early_region_placement(ts)
        self._init_region_data()

    def _early_region_placement(self, ts: TrainingSimulator):
        """
        Implemented the early region placement strategy to avoid GPU memory fragmentation.
        It maps all region data into a contiguous memory space and
        reuses the same memory space for regions that do not coexist.

        Args:
            ts (TrainingSimulator): the best training simulator, which records region execution flow.

        Raises:
            NotImplementedError: due to the naive implementation,
                it may not find a suitable region placement strategy for the given execution flow.
        """
        reg_flow = torch.cat([ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0)
        # Block count = the largest number of pool regions flagged in any single row of the flow
        # (rows presumably correspond to execution steps -- TODO confirm against TrainingSimulator).
        mem_block_num = torch.max(torch.sum(reg_flow[:, self.rid_in_pool], dim=1))
        coexist_matrix = torch.logical_or(ts.fwd_reg_flow, ts.bwd_reg_flow)

        block_to_regs = {block_idx: [] for block_idx in range(mem_block_num)}
        for reg in self.region_list:
            if reg.r_id not in self.rid_in_pool:
                continue

            # every region flagged in at least one row in which the current region is flagged
            cur_reg_appears = coexist_matrix[:, reg.r_id]
            cur_reg_coexists = torch.sum(coexist_matrix[cur_reg_appears], dim=0).bool()
            # first fit: the first block holding no region that coexists with the current one
            for block_idx in range(mem_block_num):
                if not any(cur_reg_coexists[block_to_regs[block_idx]]):
                    block_to_regs[block_idx].append(reg.r_id)
                    self.reg_to_block[reg.r_id] = block_idx
                    break

            if reg.r_id not in self.reg_to_block:
                raise NotImplementedError(
                    f"can not find a block from the memory pool to store parameters of the region"
                )

        # one fp16 buffer of mem_block_num * mem_block_size bytes (2 bytes per half), chunked per block
        self.memory_pool = torch.chunk(
            torch.zeros(int(mem_block_num * self.mem_block_size / 2), dtype=torch.half, device="cuda"),
            chunks=int(mem_block_num),
        )

    def _merge_small_regions(self, orig_reg_list: List[Region]) -> List[Region]:
        """
        Merge smaller regions into larger ones for better bandwidth utilization and easier management.
        It is inspired by Gemini.

        Original regions are packed greedily, in order: a new region is started whenever adding the
        next original region would exceed ``self.mem_block_size`` bytes. New regions are numbered
        consecutively, starting from the id of the first original region, and the parameter-to-region
        map is updated to point at the merged regions.

        Args:
            orig_reg_list (List[Region]): original region list.

        Returns:
            List[Region]: region list after merging.
        """
        r_id = orig_reg_list[0].r_id
        region = Region(r_id=r_id)
        region_list = [region]

        for orig_reg in orig_reg_list:
            if region.param_size + orig_reg.param_size > self.mem_block_size:
                r_id += 1
                region = Region(r_id=r_id)
                region_list.append(region)
            region.param_size += orig_reg.param_size
            region.param_num += orig_reg.param_num
            region.nodes.extend(orig_reg.nodes)
            region.fp16_params.extend(orig_reg.fp16_params)
            self.__update_param_region_map(orig_reg.fp16_params, region)

        return region_list

    def _search_block_size(
        self, region_list: List[Region], search_interval_byte: int = 1024, search_range_byte: int = 128 * 1024**2
    ) -> int:
        """
        Search for a suitable memory block size.

        Candidate sizes run from the largest unshared region size up to ``search_range_byte`` bytes
        beyond it, in steps of ``search_interval_byte``. Each candidate is scored by the bytes wasted
        when the region sizes are packed greedily, in order, into blocks of that size. This is the
        same packing rule ``_merge_small_regions`` applies. The smallest candidate with the least
        waste wins.

        Args:
            region_list (List[Region]): region list.
            search_interval_byte (int): searching interval in byte.
            search_range_byte (int): searching range in byte.

        Returns:
            int: the best memory block size.
        """

        def _get_wasted_mem(size_list: List[int], blk_size: int) -> int:
            """
            Get wasted byte for a certain block size.
            """
            acc_wasted = 0
            used = 0  # bytes occupied in the block currently being filled
            for s in size_list:
                if used + s > blk_size:
                    # ``s`` does not fit: close the current block and open a fresh, empty one.
                    acc_wasted += blk_size - used
                    used = 0
                used += s
            acc_wasted += blk_size - used
            return acc_wasted

        # regions sharing parameters with another region are left out of the estimate
        param_size_list = [region.param_size for region in region_list if region.r_id == region.shared_rid]

        start_size = max(param_size_list)
        min_mem_waste = float("+inf")
        best_block_size = start_size

        for block_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
            temp_waste = _get_wasted_mem(param_size_list, block_size)
            if temp_waste < min_mem_waste:
                min_mem_waste = temp_waste
                best_block_size = block_size

        return best_block_size

    def _init_region_data(self):
        """
        Initialize region data, which maps the parameters in the region to a contiguous memory space.

        Pool-resident regions receive their memory-pool block as pre-allocated storage. The later
        region of a sharing pair reuses its partner's data instead of allocating its own. Every
        region gets a view of one fp32 scratch buffer sized for the largest region.
        """
        self.temp_fp32_data = torch.zeros(self.max_param_num, device="cuda", dtype=torch.float32)

        for region in self.region_list:
            pre_alloc_tensor = None
            if self.require_pool and region.r_id in self.rid_in_pool:
                pre_alloc_tensor = self.memory_pool[self.reg_to_block[region.r_id]]

            if region.r_id <= region.shared_rid:
                region.init_param_data(pre_alloc_tensor)
            else:
                # the later region of a sharing pair aliases the storage of its partner
                shared_region = self.region_list[region.shared_rid]
                region.fp16_data = shared_region.fp16_data
                region.fp32_data = shared_region.fp32_data
                region.param_to_range = shared_region.param_to_range
            region.temp_fp32_data = self.temp_fp32_data[: region.param_num].detach()

        torch.cuda.empty_cache()

    def _process_shared_region(self):
        """
        Special processing for the shared region, which uses GPT2 and Bert case as a priori knowledge.

        The first region of the sharing pair must end with an ``nn.Embedding`` module call. If the
        second region holds more parameters than the first, it is split with ``Region.split`` right
        after its first linear node. The split-off region is inserted at index ``new_reg.r_id`` and
        every later region is renumbered. Finally the two regions of the pair are linked through
        ``shared_rid``.

        Raises:
            NotImplementedError: if the second region has to be split but contains no linear node.
        """
        if not self.shared_region_pairs:
            return

        assert len(self.shared_region_pairs) <= 1
        former_reg, latter_reg = self.shared_region_pairs[0]
        assert latter_reg.param_num >= former_reg.param_num
        embedding_node = former_reg.nodes[-1]
        assert embedding_node.op == "call_module" and isinstance(
            self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding
        )

        if latter_reg.param_num > former_reg.param_num:
            for idx, n in enumerate(latter_reg.nodes):
                if (
                    n.op == "call_module" and isinstance(self.root_module.get_submodule(n.target), torch.nn.Linear)
                ) or (n.op == "call_function" and n.target is torch.nn.functional.linear):
                    cut_node_idx = idx + 1
                    break
            else:
                # without a linear node there is no split point
                raise NotImplementedError("Can not find a linear node to split the shared region at.")

            assert len(latter_reg.fp16_params) == 2
            new_reg = latter_reg.split(cut_node_idx, 1)
            for p in new_reg.fp16_params:
                self.param_region_map[p] = new_reg
            self.region_list.insert(new_reg.r_id, new_reg)
            for reg in self.region_list[new_reg.r_id + 1 :]:
                # unshared regions refer to themselves; keep that true after the shift
                if reg.shared_rid == reg.r_id:
                    reg.shared_rid += 1
                reg.r_id += 1

        latter_reg.shared_rid = former_reg.r_id
        former_reg.shared_rid = latter_reg.r_id

    def _linearize_graph(self) -> List[Region]:
        """Linearizing the graph

        Walks ``self.graph`` in node order and cuts it into regions. A region is closed either right
        after the last activation-producing node when a new parameter computation starts, or right
        after a node that ends a parameter computation.

        Returns:
            List[Region]: each region contains the actual 'node' in linearized manner.

        Raises:
            ValueError: if a name in ``self.cnode`` is not a node of the graph.

        Remarks:
            Do merge the inplace ops and shape-consistency ops into the previous node.
        """

        # target names that could be seen as common node
        common_ops = {"getattr", "getitem", "size"}

        def _is_cop(target: Any) -> bool:
            """Check if an op could be seen as common node

            Args:
                target (Any): node target

            Returns:
                bool
            """
            if isinstance(target, str):
                return target in common_ops
            # callables without ``__name__`` (e.g. partial objects) are never common ops
            return getattr(target, "__name__", None) in common_ops

        def _is_act(data: Any) -> bool:
            """Check if the meta data is, or (recursively) contains, a ``torch.Tensor``

            Args:
                data (Any): meta_data

            Returns:
                bool
            """
            if isinstance(data, torch.Tensor):
                return True
            if isinstance(data, (tuple, list)):
                return any(_is_act(d) for d in data)
            return False

        def _owns_params_or_buffers(target: str) -> bool:
            """Check whether the submodule at ``target`` directly owns a parameter or buffer."""
            submod = self.root_module.get_submodule(target)
            return (
                len(list(submod.named_parameters(recurse=False))) != 0
                or len(list(submod.named_buffers(recurse=False))) != 0
            )

        def _maybe_param_comp_start() -> bool:
            """Check if the current loop node ``n`` could be seen as parameter computation start

            True for a ``get_attr`` node or a module call whose module directly owns parameters or
            buffers, provided no parameter-only node still has unvisited users.

            Returns:
                bool
            """
            if n.op == "get_attr":
                label = True
            elif n.op == "call_module":
                label = _owns_params_or_buffers(n.target)
            else:
                label = False
            return label and not sum(param_op_deps.values())

        def _is_param_comp_end() -> bool:
            """Check if the current loop node ``n`` could be seen as parameter computation end

            Returns:
                bool
            """

            def _is_inplace(n: Node):
                """Get the inplace argument from ``torch.fx.Node``"""
                inplace = False
                if n.op == "call_function":
                    inplace = n.kwargs.get("inplace", False)
                elif n.op == "call_module":
                    inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
                return inplace

            if n.op == "call_module":
                label = _owns_params_or_buffers(n.target)
            elif n.op == "call_function":
                # NOTE(review): the second test checks ``n.target``, not the input's target -- confirm intent.
                label = any(x.name in self.only_param_ops for x in n.all_input_nodes) and any(
                    x.name not in self.only_param_ops and not _is_cop(n.target) for x in n.all_input_nodes
                )
            else:
                label = False
            # an in-place user must stay in the same region as the node it modifies
            return label and not sum(param_op_deps.values()) and not any(map(_is_inplace, n.users))

        def _exception_node_handling():
            # TODO meta info prop bug
            if "transpose" in n.name and n.meta["fwd_out"][0].dim() <= 2:
                n.meta["fwd_out"] = []

        # make sure that item in cnode is valid
        if self.cnode:
            for name in self.cnode:
                cnode_node = next((node for node in self.graph.nodes if node.name == name), None)
                if cnode_node is None:
                    raise ValueError(f"Common node name {name} not in graph.")
                assert cnode_node.op == "placeholder", f"Common node {name} is not an input of the model."
        else:
            self.cnode = []

        node_id = 0
        region_id = 0

        # parameter-only node -> number of its (non-output) users not visited yet
        param_op_deps = {}

        region_list = []
        region = Region(r_id=region_id)

        act_n = None
        for n in self.graph.nodes:
            if n.op == "placeholder" or n.op == "output":
                continue

            for n_par in n.all_input_nodes:
                if n_par.op != "placeholder" and n_par.name in self.only_param_ops:
                    param_op_deps[n_par] -= 1

            # A new parameter computation starts: close the region right after the last
            # activation node and carry the nodes that follow it over to the new region.
            if act_n in region.nodes and _maybe_param_comp_start():
                border_n_idx = region.nodes.index(act_n)
                ns = region.nodes[border_n_idx + 1 :]
                region.nodes = region.nodes[: border_n_idx + 1]
                region_list.append(region)
                region_id += 1
                region = Region(r_id=region_id)
                region.nodes = ns

            _exception_node_handling()
            region.nodes.append(n)
            self._set_node_and_region_info(node_id, n, region)
            node_id += 1

            # if the node could free all dependencies in graph
            # we could begin a new region
            if _is_param_comp_end():
                region_list.append(region)
                region_id += 1
                region = Region(r_id=region_id)

            # propagate common node attr if possible
            if all(node.name in self.cnode for node in n.all_input_nodes) or _is_cop(n.target):
                self.cnode.append(n.name)

            # propagate param node attr if possible
            if all(node.name in self.only_param_ops for node in n.all_input_nodes) or n.op == "get_attr":
                self.only_param_ops.append(n.name)
                param_op_deps[n] = len([user for user in n.users if user.op != "output"])

            # record last activation node
            if _is_act(n._meta_data):
                act_n = n

        if len(region.nodes):
            region_list.append(region)

        return region_list

    def _set_node_and_region_info(self, node_id: int, cur_n: Node, cur_reg: Region):
        """
        Attach a ``NodeInfo`` to ``cur_n`` and register the parameters it uses with ``cur_reg``.

        For a module call, the module's own (non-recursive) parameters are registered; for a
        ``get_attr`` node, the fetched attribute is registered if it is a ``torch.nn.Parameter``.

        Args:
            node_id (int): linear index of ``cur_n`` among the visited nodes.
            cur_n (Node): the node being placed.
            cur_reg (Region): the region ``cur_n`` is appended to.
        """
        cur_n.node_info = NodeInfo(node_id)

        if cur_n.op == "call_module":
            submod = self.root_module.get_submodule(cur_n.target)
            for p in submod.parameters(recurse=False):
                self._register_param(p, cur_reg)
        elif cur_n.op == "get_attr":
            attr_itr = self.root_module
            for atom in cur_n.target.split("."):
                attr_itr = getattr(attr_itr, atom)
            if isinstance(attr_itr, torch.nn.Parameter):
                self._register_param(attr_itr, cur_reg)

    def _register_param(self, param: torch.nn.Parameter, cur_reg: Region):
        """
        Add ``param`` to ``cur_reg`` and account for its size.

        If ``param`` already belongs to another region, the two regions are linked through
        ``shared_rid`` and recorded as a sharing pair; otherwise ``cur_reg`` becomes its owner.
        """
        if param in self.param_region_map:
            owner_reg = self.param_region_map[param]
            cur_reg.shared_rid = owner_reg.r_id
            owner_reg.shared_rid = cur_reg.r_id
            self.shared_region_pairs.append((owner_reg, cur_reg))
        else:
            self.param_region_map[param] = cur_reg

        cur_reg.fp16_params.append(param)
        cur_reg.param_num += param.data.numel()
        cur_reg.param_size += param.data.numel() * param.data.element_size()

    def get_region(self, param: torch.nn.Parameter) -> Region:
        """
        Return the region owning the parameter.

        Args:
            param (torch.nn.Parameter): a torch parameter object

        Returns:
            Region: the region ``param`` is currently mapped to.

        Raises:
            KeyError: if ``param`` was never registered with a region.
        """
        return self.param_region_map[param]

    def __update_param_region_map(self, params: List[torch.nn.Parameter], region: Region):
        """Map every parameter in ``params`` to ``region``."""
        for p in params:
            self.param_region_map[p] = region