from typing import Any, Dict, List, Optional, Tuple

import torch
from torch.fx import Graph, Node

from .region import Region
from .solver import SolverFactory
from .training_simulator import TrainingSimulator
from .util import NodeInfo


class RegionManager:
    """
    RegionManager is used to construct and manage the offload plan for the model execution.

    Args:
        graph (Graph): a Graph object used for analysis and strategy generation.
        solver_name (str): a solver name which specifies the preferences for plan searching.
        memory_budget (float): the given memory budget.
        cnode (List[str], optional): Common node list; every name must be a placeholder
            (input) node of ``graph``. The list is copied, so the caller's list is never mutated.
    """

    def __init__(
        self,
        graph: Graph,
        solver_name: str = "asyn",
        memory_budget: float = -1.0,
        cnode: Optional[List[str]] = None,
    ):
        self.graph = graph
        assert graph.owning_module is not None, "The given graph is not associated with a owning_module"
        self.root_module = self.graph.owning_module
        self.nodes = list(graph.nodes)
        # Copy: _linearize_graph appends propagated common-node names to self.cnode,
        # which previously mutated the list passed in by the caller.
        self.cnode = list(cnode) if cnode is not None else None
        self.only_param_ops = []
        self.param_region_map: Dict[torch.nn.Parameter, Region] = dict()
        self.shared_region_pairs: List[Tuple[Region, Region]] = list()
        self.region_list: List[Region] = list()
        self.rid_in_pool: List[int] = list()
        self.mem_block_size: int = 0
        self.memory_budget = memory_budget

        self.solver_name = solver_name
        # Only the asynchronous solver places regions into a shared memory pool.
        self.require_pool: bool = solver_name == "asyn"

        self.reg_to_block: Dict[int, int] = dict()

    def _build_regions(self):
        """
        1. Pre-processing, mainly contains linearized computing graph and
            merge smaller regions into larger ones.
        2. Construct a solver to search for an efficient offload strategy.
        3. Post-processing, mainly contains early region placement if using asynchronous mode,
            and initialize region data.
        """
        self._pre_process()

        solver_cls = SolverFactory.create(self.solver_name)
        solver = solver_cls(self.region_list, self.memory_budget)
        solver._call_solver()

        self._post_process(solver.best_ts)

    def _pre_process(self):
        """
        Linearize the graph into regions, merge the regions that lie between the (at most one)
        pair of parameter-sharing regions, and finalize ``self.region_list``.

        Also records the largest region parameter count in ``self.max_param_num`` and subtracts
        a float32 buffer of that size from ``self.memory_budget``.

        Raises:
            NotImplementedError: if more than one pair of parameter-sharing regions is found.
        """
        init_region_list = self._linearize_graph()

        if len(self.shared_region_pairs) > 1:
            raise NotImplementedError("The current version only considers at most one pair of parameter sharing.")

        elif len(self.shared_region_pairs) == 1:
            shared_regs = self.shared_region_pairs[0]
            assert shared_regs[0].shared_rid == shared_regs[1].r_id and shared_regs[1].shared_rid == shared_regs[0].r_id
            fst_id = shared_regs[0].r_id
            lst_id = shared_regs[1].r_id
            # Only the regions strictly between the two shared regions are merged / pooled.
            regs_left_out = init_region_list[: fst_id + 1]
            regs_right_out = init_region_list[lst_id:]
            hold_regs = init_region_list[fst_id + 1 : lst_id]
        else:
            regs_left_out = []
            regs_right_out = []
            hold_regs = init_region_list

        self.mem_block_size = self._search_block_size(hold_regs)
        hold_regs = self._merge_small_regions(hold_regs)

        if self.require_pool:
            for reg in hold_regs:
                reg.in_mem_pool_flag = True
                self.rid_in_pool.append(reg.r_id)

        self.region_list.extend(regs_left_out)
        self.region_list.extend(hold_regs)

        # Merging may have reduced the number of regions, so renumber the right-hand regions
        # and point their sharing partner at the new id.
        # NOTE(review): for right-hand regions that do not share parameters, reg.shared_rid
        # presumably still holds the pre-renumbering r_id (if Region defaults shared_rid to
        # r_id), in which case this index may be out of range or hit an unrelated region —
        # confirm against Region before relying on more than one right-hand region.
        for reg in regs_right_out:
            reg.r_id = self.region_list[-1].r_id + 1
            self.region_list[reg.shared_rid].shared_rid = reg.r_id
            self.region_list.append(reg)

        self._process_shared_region()

        self.max_param_num = max([reg.param_num for reg in self.region_list])
        self.memory_budget -= self.max_param_num * torch.tensor([], dtype=torch.float32).element_size()

    def _post_process(self, ts: Optional[TrainingSimulator] = None):
        """
        Place pooled regions into memory blocks (asynchronous mode only), then
        initialize the data of every region.

        Args:
            ts (TrainingSimulator): the best training simulator found by the solver.
        """
        if self.require_pool:
            self._early_region_placement(ts)
        self._init_region_data()

    def _early_region_placement(self, ts: TrainingSimulator):
        """
        Implemented the early region placement strategy to avoid GPU memory fragmentation.
        It maps all region data into a contiguous memory space and
        reuses the same memory space for regions that do not coexist.

        Args:
            ts (TrainingSimulator): the best training simulator, which records region execution flow.

        Raises:
            NotImplementedError: due to the naive implementation,
                it may not find a suitable region placement strategy for the given execution flow.
        """
        reg_flow = torch.cat([ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0)
        # Number of blocks = the largest number of pooled regions alive at the same step.
        mem_block_num = int(torch.max(torch.sum(reg_flow[:, self.rid_in_pool], dim=1)).item())
        coexist_matrix = torch.logical_or(ts.fwd_reg_flow, ts.bwd_reg_flow)

        pooled_rids = set(self.rid_in_pool)
        block_to_regs: Dict[int, List[int]] = {block_idx: [] for block_idx in range(mem_block_num)}
        for reg in self.region_list:
            if reg.r_id not in pooled_rids:
                continue
            cur_reg_appears = coexist_matrix[:, reg.r_id]
            # Regions that are live in any step where the current region is live.
            cur_reg_coexists = torch.sum(coexist_matrix[cur_reg_appears], dim=0).bool()
            # First-fit: take the first block none of whose regions coexist with this one.
            for block_idx in range(mem_block_num):
                if not any(cur_reg_coexists[block_to_regs[block_idx]]):
                    block_to_regs[block_idx].append(reg.r_id)
                    self.reg_to_block[reg.r_id] = block_idx
                    break

            if reg.r_id not in self.reg_to_block:
                raise NotImplementedError("can not find a block from the memory pool to store parameters of the region")

        # mem_block_size is in bytes; the pool stores fp16 (2 bytes per element).
        self.memory_pool = torch.chunk(
            torch.zeros(mem_block_num * self.mem_block_size // 2, dtype=torch.half, device="cuda"),
            chunks=mem_block_num,
        )

    def _merge_small_regions(self, orig_reg_list: List[Region]) -> List[Region]:
        """
        Merge smaller regions into larger ones for better bandwidth utilization and easier management.
        It is inspired by Gemini.

        Regions are merged greedily in order; a new merged region is started whenever adding the
        next original region would push the current one past ``self.mem_block_size`` bytes.
        ``self.param_region_map`` is updated to point at the merged regions.

        Args:
            orig_reg_list (List[Region]): original region list.

        Returns:
            List[Region]: region list after merging.
        """
        r_id = orig_reg_list[0].r_id
        region = Region(r_id=r_id)
        region_list = [region]

        for orig_reg in orig_reg_list:
            if region_list[-1].param_size + orig_reg.param_size > self.mem_block_size:
                r_id += 1
                region = Region(r_id=r_id)
                region_list.append(region)
            region.param_size += orig_reg.param_size
            region.param_num += orig_reg.param_num
            region.nodes.extend(orig_reg.nodes)
            region.fp16_params.extend(orig_reg.fp16_params)
            self.__update_param_region_map(orig_reg.fp16_params, region)

        return region_list

    def _search_block_size(
        self, region_list: List[Region], search_interval_byte: int = 1024, search_range_byte: int = 128 * 1024**2
    ) -> int:
        """
        Search for a suitable memory block size.

        Candidates start at the largest parameter size among regions whose ``r_id`` equals their
        ``shared_rid`` and grow by ``search_interval_byte`` up to ``search_range_byte`` above it.
        The candidate that wastes the fewest bytes under greedy in-order packing wins
        (the smallest such candidate on ties).

        Args:
            region_list (List[Region]): region list.
            search_interval_byte (int): searching interval in byte.
            search_range_byte (int): searching range in byte.

        Returns:
            int: the best memory block size.
        """

        def _get_wasted_mem(size_list: List[int], blk_size: int) -> int:
            """
            Get wasted byte for a certain block size when sizes are packed greedily in order.
            """
            acc_wasted = 0
            left = 0
            for s in size_list:
                if left + s > blk_size:
                    # s does not fit: the unused tail of the current block is wasted and
                    # a fresh block is opened (s is added to it just below).
                    acc_wasted += blk_size - left
                    left = 0
                left += s
            acc_wasted += blk_size - left
            return acc_wasted

        param_size_list = [region.param_size for region in region_list if region.r_id == region.shared_rid]

        start_size = max(param_size_list)
        min_mem_waste = float("+inf")
        best_block_size = start_size

        for block_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
            temp_waste = _get_wasted_mem(param_size_list, block_size)
            if temp_waste < min_mem_waste:
                min_mem_waste = temp_waste
                best_block_size = block_size

        return best_block_size

    def _init_region_data(self):
        """
        Initialize region data, which maps the parameters in the region to a contiguous memory space.

        Pooled regions receive their pre-allocated memory-pool block. A region whose ``r_id`` is
        greater than its ``shared_rid`` reuses the data of its sharing partner instead of
        allocating its own. Every region gets a view of one shared float32 scratch buffer.
        """
        self.temp_fp32_data = torch.zeros(self.max_param_num, device="cuda", dtype=torch.float32)

        for region in self.region_list:
            pre_alloc_tensor = None
            if self.require_pool and region.r_id in self.rid_in_pool:
                block_idx = self.reg_to_block[region.r_id]
                pre_alloc_tensor = self.memory_pool[block_idx]

            if region.r_id <= region.shared_rid:
                region.init_param_data(pre_alloc_tensor)
            else:
                shared_region = self.region_list[region.shared_rid]
                region.fp16_data = shared_region.fp16_data
                region.fp32_data = shared_region.fp32_data
                region.param_to_range = shared_region.param_to_range
            region.temp_fp32_data = self.temp_fp32_data[: region.param_num].detach()

        torch.cuda.empty_cache()

    def _process_shared_region(self):
        """
        Special processing for the shared region, which uses GPT2 and Bert case as a priori knowledge.

        If the latter shared region owns more parameters than the former (embedding) region, it is
        split right after its first linear node, and the new region is inserted into
        ``self.region_list`` with the ids of the following regions shifted by one.

        Raises:
            NotImplementedError: if the latter region must be split but contains no linear node.
        """
        if len(self.shared_region_pairs):
            assert len(self.shared_region_pairs) <= 1
            former_reg, latter_reg = self.shared_region_pairs[0]
            assert latter_reg.param_num >= former_reg.param_num
            embedding_node = former_reg.nodes[-1]
            assert embedding_node.op == "call_module" and isinstance(
                self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding
            )
            if latter_reg.param_num > former_reg.param_num:
                for idx, n in enumerate(latter_reg.nodes):
                    if (
                        n.op == "call_module" and isinstance(self.root_module.get_submodule(n.target), torch.nn.Linear)
                    ) or (n.op == "call_function" and n.target is torch.nn.functional.linear):
                        cut_node_idx = idx + 1
                        break
                else:
                    # Previously this surfaced as an UnboundLocalError on cut_node_idx.
                    raise NotImplementedError("The latter shared region does not contain a linear node to split at.")
                assert len(latter_reg.fp16_params) == 2
                new_reg = latter_reg.split(cut_node_idx, 1)
                for p in new_reg.fp16_params:
                    self.param_region_map[p] = new_reg
                self.region_list.insert(new_reg.r_id, new_reg)
                for reg in self.region_list[new_reg.r_id + 1 :]:
                    reg.r_id += 1
            latter_reg.shared_rid = former_reg.r_id
            former_reg.shared_rid = latter_reg.r_id

    def _linearize_graph(self) -> List[Region]:
        """Linearizing the graph

        Args:
            graph (Graph): The computing graph to be optimized.

        Returns:
            List[Region]: each region contains the actual 'node' in linearized manner.

        Remarks:
            Do merge the inplace ops and shape-consistency ops into the previous node.

        Raises:
            AssertionError: if a name in ``self.cnode`` is not a placeholder node.
            ValueError: if a name in ``self.cnode`` is not in the graph.
        """

        # List of target name that could be seen as common node
        common_ops = ["getattr", "getitem", "size"]

        # NOTE: the helpers below read `n` (the node currently visited by the main loop)
        # and `param_op_deps` from this enclosing scope at call time.

        def _is_cop(target: Any) -> bool:
            """Check if an op could be seen as common node"""
            if isinstance(target, str):
                return target in common_ops
            return getattr(target, "__name__", None) in common_ops

        def _is_act(data: Any) -> bool:
            """Check whether meta data is, or (recursively) contains, a tensor"""
            if isinstance(data, torch.Tensor):
                return True
            if isinstance(data, (tuple, list)):
                return any(_is_act(d) for d in data)
            return False

        def _owns_params_or_buffers(target: str) -> bool:
            """Check whether the submodule at `target` directly owns parameters or buffers"""
            submod = self.root_module.get_submodule(target)
            return (
                len(list(submod.named_parameters(recurse=False))) != 0
                or len(list(submod.named_buffers(recurse=False))) != 0
            )

        def _maybe_param_comp_start() -> bool:
            """Check if the current node `n` could be seen as parameter computation start"""
            if n.op == "get_attr":
                label = True
            elif n.op == "call_module":
                label = _owns_params_or_buffers(n.target)
            else:
                label = False
            # Only start when no pending parameter-only op still has unconsumed users.
            return label and not sum(param_op_deps.values())

        def _is_param_comp_end() -> bool:
            """Check if the current node `n` could be seen as parameter computation end"""

            def _is_inplace(user: Node):
                """Get the inplace argument from ``torch.fx.Node``"""
                inplace = False
                if user.op == "call_function":
                    inplace = user.kwargs.get("inplace", False)
                elif user.op == "call_module":
                    inplace = getattr(user.graph.owning_module.get_submodule(user.target), "inplace", False)
                return inplace

            label = False

            if n.op == "call_module":
                label = _owns_params_or_buffers(n.target)

            elif n.op == "call_function":
                # Mixes at least one parameter-only input with at least one other input.
                label = any(x.name in self.only_param_ops for x in n.all_input_nodes) and any(
                    x.name not in self.only_param_ops and not _is_cop(n.target) for x in n.all_input_nodes
                )

            # Do not end the region if a user modifies this node's output in place.
            return label and not sum(param_op_deps.values()) and not any(map(_is_inplace, n.users))

        def _exception_node_handling():
            # TODO meta info prop bug
            if "transpose" in n.name and n.meta["fwd_out"][0].dim() <= 2:
                n.meta["fwd_out"] = []

        # make sure that item in cnode is valid
        if self.cnode:
            for name in self.cnode:
                try:
                    assert (
                        next(node for node in self.graph.nodes if node.name == name).op == "placeholder"
                    ), f"Common node {name} is not an input of the model."
                except StopIteration:
                    raise ValueError(f"Common node name {name} not in graph.") from None

        else:
            self.cnode = []

        node_id = 0
        region_id = 0

        param_op_deps = {}

        deps = {}
        region_list = []
        region = Region(r_id=region_id)

        act_n = None

        for n in self.graph.nodes:
            if n.op != "placeholder" and n.op != "output":
                for n_par in n.all_input_nodes:
                    if n_par.op != "placeholder" and n_par.name not in self.cnode:
                        deps[n_par] -= 1
                    if n_par.op != "placeholder" and n_par.name in self.only_param_ops:
                        param_op_deps[n_par] -= 1

                # A parameter computation starts: close the current region right after the
                # last activation node and carry the trailing nodes into the new region.
                if act_n in region.nodes and _maybe_param_comp_start():
                    ns = []
                    border_n_idx = region.nodes.index(act_n)
                    if border_n_idx < len(region.nodes):
                        ns = region.nodes[border_n_idx + 1 :]
                        region.nodes = region.nodes[: border_n_idx + 1]
                    region_list.append(region)
                    region_id += 1
                    region = Region(r_id=region_id)
                    region.nodes = ns

                _exception_node_handling()
                region.nodes.append(n)
                self._set_node_and_region_info(node_id, n, region)
                node_id += 1

                # if the node could free all dependencies in graph
                # we could begin a new region
                if _is_param_comp_end():
                    region_list.append(region)
                    region_id += 1
                    region = Region(r_id=region_id)

                # propagate common node attr if possible
                if len(n.all_input_nodes) == len(
                    [node for node in n.all_input_nodes if node.name in self.cnode]
                ) or _is_cop(n.target):
                    self.cnode.append(n.name)
                else:
                    deps[n] = len([user for user in n.users if user.op != "output"])

                # propagate param node attr if possible
                if (
                    len(n.all_input_nodes)
                    == len([node for node in n.all_input_nodes if node.name in self.only_param_ops])
                    or n.op == "get_attr"
                ):
                    self.only_param_ops.append(n.name)
                    param_op_deps[n] = len([user for user in n.users if user.op != "output"])

                # record last activation node
                if _is_act(n._meta_data):
                    act_n = n

        if len(region.nodes):
            region_list.append(region)

        return region_list

    def _register_param(self, p: torch.nn.Parameter, cur_reg: Region):
        """
        Attach parameter `p` to `cur_reg`.

        If `p` is already owned by another region, the two regions are recorded as a sharing
        pair. The parameter is appended to ``cur_reg.fp16_params`` and counted in its
        ``param_num`` / ``param_size`` (bytes) in either case.
        """
        if p in self.param_region_map:
            owner = self.param_region_map[p]
            cur_reg.shared_rid = owner.r_id
            owner.shared_rid = cur_reg.r_id
            self.shared_region_pairs.append((owner, cur_reg))
        else:
            self.param_region_map[p] = cur_reg

        cur_reg.fp16_params.append(p)
        cur_reg.param_num += p.data.numel()
        cur_reg.param_size += p.data.numel() * p.data.element_size()

    def _set_node_and_region_info(self, node_id: int, cur_n: Node, cur_reg: Region):
        """
        Tag `cur_n` with its linear index and register the parameters it directly uses.

        For ``call_module`` nodes, the submodule's own (non-recursive) parameters are registered.
        For ``get_attr`` nodes, the fetched attribute is registered if it is a Parameter.
        """
        cur_n.node_info = NodeInfo(node_id)

        if cur_n.op == "call_module":
            submod = self.root_module.get_submodule(cur_n.target)
            for p in list(submod.parameters(recurse=False)):
                self._register_param(p, cur_reg)

        elif cur_n.op == "get_attr":
            attr_itr = self.root_module
            for atom in cur_n.target.split("."):
                attr_itr = getattr(attr_itr, atom)

            if isinstance(attr_itr, torch.nn.Parameter):
                self._register_param(attr_itr, cur_reg)

    def get_region(self, param: torch.nn.Parameter) -> Region:
        """
        Return the region owning the parameter.

        Args:
            param (torch.nn.Parameter): a torch parameter object
        """
        return self.param_region_map[param]

    def __update_param_region_map(self, params: List[torch.nn.Parameter], region: Region):
        """Point every parameter in `params` at `region` in ``self.param_region_map``."""
        for p in params:
            self.param_region_map[p] = region