import time from typing import List, Dict, Type from abc import ABC, abstractmethod NOT_NVML = False try: from pynvml import * except: NOT_NVML = True import torch from torch.fx.node import Node from colossalai.utils.cuda import get_current_device from .training_simulator import TrainingSimulator, SynTrainingSimulator, AsynTrainingSimulator from .region import Region from .util import NodeInfo, NvDevicePower def benchmark_func(func, number=1, repeat=1, warmup=3): """ benchmark data transfer cost. """ for i in range(warmup): func() costs = [] for i in range(repeat): torch.cuda.synchronize() begin = time.time() for i in range(number): func() torch.cuda.synchronize() costs.append((time.time() - begin) / number) return sum(costs) / len(costs) class Solver(ABC): """ The parameter offload solver. Args: region_list (List[Region]): represents the linearized DNN computing graph. memory_budget (float): the given memory budget. error_factor (float): the error factor. It is used to reduce the memory budget. Due to some errors in the estimation of peak memory and execution time. """ def __init__(self, region_list: List[Region], memory_budget: float = -1.0, error_factor: float = 0.95) -> None: self.region_list = region_list self.error_factor: float = error_factor if memory_budget > 0: self.memory_budget = memory_budget * self.error_factor else: self.memory_budget = torch.cuda.get_device_properties( get_current_device()).total_memory * self.error_factor self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth() self.comp_power: float = self._extract_computing_power() @abstractmethod def _call_solver(self): raise NotImplementedError @abstractmethod def _try_to_offload(self, *args): raise NotImplementedError @abstractmethod def _eval_one_choice(self, *args): raise NotImplementedError def _compute_offload_profit(self, total_mem_saving: float, peak_mem_saving: float, extra_cost: float): """ Compute the profits of the offload strategies, which packages the memory savings information for subsequent comparisons. Args: total_mem_saving (float): the total memory saving of the offload strategy. peak_mem_saving (float): the peak memory saving of the offload strategy. extra_cost (float): extra data transfer cost. Returns: tuple: profit information, the first term represents memory savings per unit of time. """ if extra_cost == 0: # means data transfer overhead can be completely overlapped return (float('inf'), total_mem_saving, peak_mem_saving) return (total_mem_saving / extra_cost, total_mem_saving, peak_mem_saving) def _compare_profit(self, profit_a: tuple, profit_b: tuple) -> bool: """ Compare the profits of the two offload strategies using the dictionary order algorithm. Args: profit_a (tuple): the profit of a offload strategy. profit_b (tuple): the profit of another offload strategy. Returns: bool: whether profit_a is greater than profit_b. """ for val1, val2 in zip(profit_a, profit_b): if val1 != val2: return val1 > val2 return False def _update_state(self, best_ts: TrainingSimulator): """ Update the solver state. """ self.best_ts = best_ts self._update_node_mem_info(best_ts.fwd_node_mem, best_ts.bwd_node_mem) def _update_node_mem_info(self, fwd_mem_info: Dict[Node, float], bwd_mem_info: Dict[Node, float]): """ Update the runtime memory information of the node. Args: fwd_mem_info (Dict[Node, float]): the runtime memory of each node in forward pass. bwd_mem_info (Dict[Node, float]): the runtime memory of each node in backward pass. """ for node, mem in fwd_mem_info.items(): assert hasattr(node, 'node_info') and isinstance( node.node_info, NodeInfo) node.node_info.runtime_fwd_mem = mem for node, mem in bwd_mem_info.items(): assert hasattr(node, 'node_info') and isinstance( node.node_info, NodeInfo) node.node_info.runtime_bwd_mem = mem def _extract_computing_power(self): """ return the FP16 computing performance of the current NVIDIA GPU. Raises: TypeError: Unknown NVIDIA GPU device. """ nvmlInit() handle = nvmlDeviceGetHandleByIndex(0) device_name = nvmlDeviceGetName(handle) units = 1e12 if device_name.__contains__("RTX 3080"): return NvDevicePower.RTX3080_FP16 * units elif device_name.__contains__("RTX 3090"): return NvDevicePower.RTX3090_FP16 * units elif device_name.__contains__('V100'): return NvDevicePower.V100_FP16 * units elif device_name.__contains__("A100"): return NvDevicePower.A100_FP16 * units else: raise TypeError(f'Unknown NVIDIA GPU device name {device_name}') def _profile_bandwidth(self): """ Profile the bidirectional communication bandwidth between CPU and GPU using data volumes ranging from 1KB to 1GB. """ print('profiling bandwidth ......') link_to_bandwidth = {} links = ['h2d', 'd2h'] for link in links: t_size = 1024 size_to_bandwidth = {} # from 1KB to 1GB for i in range(21): if link == 'h2d': src_tensor = torch.ones( int(t_size), dtype=torch.int8, pin_memory=True) dst_tensor = torch.ones( (int(t_size)), dtype=torch.int8, device='cuda') elif link == 'd2h': src_tensor = torch.ones( int(t_size), dtype=torch.int8, device='cuda') dst_tensor = torch.ones( (int(t_size)), dtype=torch.int8, pin_memory=True) def func(): dst_tensor.copy_(src_tensor) size_to_bandwidth[t_size] = t_size / benchmark_func(func, number=5, repeat=3) print(f'size: {t_size / 1024 ** 2:.3f} MB, ' f'{src_tensor.device.type}-to-{dst_tensor.device.type} ' f'bandwidth: {size_to_bandwidth[t_size] / 1024 ** 3:.3f} GB/s') t_size *= 2 link_to_bandwidth[link] = size_to_bandwidth return link_to_bandwidth class SynGreedySolver(Solver): def __init__(self, region_list: List[Region], memory_budget: float = -1.0) -> None: super().__init__(region_list, memory_budget) self.best_ts: SynTrainingSimulator = None self._init_state() def _init_state(self): """ Initialize the solver state when without offloading. """ ts = SynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth) ts.execute() self._update_state(ts) def _call_solver(self): """ Call the solver to search an efficient parameter offloading strategy for the linearized graph. The solver adopts greedy algorithm. Raises: NotImplementedError: Unable to find a solution for the given memory budget. """ print("search offloading strategy ......") while self.best_ts.peak_mem > self.memory_budget: offload_region = None best_ts = None max_profit = (0,) # search which region should be offloaded, # the last region does not need to be offloaded. for region in self.region_list[:-1]: if region.param_size and not region.need_offload: temp_ts, profit = self._try_to_offload(region) if self._compare_profit(profit, max_profit): offload_region = region max_profit = profit best_ts = temp_ts if offload_region is not None and best_ts is not None: offload_region.need_offload = True offload_region.is_syn = True self._update_state(best_ts) else: raise NotImplementedError( f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, " f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!") def _call_solver_l2l(self): """ The layer-wise offload strategy. """ for region in self.region_list[:-1]: region.need_offload = True region.is_syn = True def _try_to_offload(self, offload_region: Region): # record previous information orig_need_offload = offload_region.need_offload assert not orig_need_offload offload_region.need_offload = True ts, profit = self._eval_one_choice(offload_region) # restore previous information offload_region.need_offload = orig_need_offload return ts, profit def _eval_one_choice(self, offload_region: Region): """ Evaluate the profit of a strategy choice. Args: offload_region (Region): the offload region of current choice. Returns: SynTrainingSimulator: the training simulator corresponding to the current strategy. tuple: contains memory saving and cost information of the current strategy. """ ts = SynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth) ts.execute() extra_comm_cost = 2.0 * \ ts._get_communication_overhead('h2d', offload_region.param_size) # the shared region needs to be moved twice if offload_region.r_id < offload_region.shared_rid: extra_comm_cost *= 2.0 profit = self._compute_offload_profit( ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost) return ts, profit class AsynGreedySolver(Solver): def __init__(self, region_list: List[Region], memory_budget: float = -1.0, search_window_size: int = 3): super().__init__(region_list, memory_budget) self.search_window_size = search_window_size # Records the prefetch execution location of the offloaded region self.region_to_region_map = {} self.best_ts: AsynTrainingSimulator = None self._init_state() def _init_state(self): """ Initialize the solver state when without offloading. """ ts = AsynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth) ts.execute() self._update_state(ts) print("init peak memory", self.best_ts.peak_mem / 1024 ** 2, "MB") def _call_solver(self): """ Call the solver to search an efficient parameter offloading strategy for the linearized graph. The solver adopts greedy algorithm. Raises: NotImplementedError: Unable to find a solution for the given memory budget. """ print("search for offloading strategy ......") # Records the prefetch execution location of the offloaded region region_to_region_map = {} while self.best_ts.peak_mem > self.memory_budget: region_to_offload = None max_offload_profit = (0,) best_offl_ts = None # search which region should be offloaded, # the last region does not need to be offloaded for region in self.region_list[:-1]: if region.param_size and not region.need_offload: max_prefetch_profit = (0,) best_pref_ts = None # search when to prefetch the region offloaded for host_region in self.region_list[region.r_id + 1:region.r_id + 1 + self.search_window_size]: if host_region.bwd_prefetch_region is not None: continue temp_ts, profit = self._try_to_offload( host_region, region) if self._compare_profit(profit, max_prefetch_profit): region_to_region_map[region.r_id] = host_region max_prefetch_profit = profit best_pref_ts = temp_ts if profit[0] == float('inf'): break if self._compare_profit(max_prefetch_profit, max_offload_profit): region_to_offload = region max_offload_profit = max_prefetch_profit best_offl_ts = best_pref_ts if (region_to_offload is not None) and (best_offl_ts is not None): region_to_offload.need_offload = True if region_to_region_map[region_to_offload.r_id] == region_to_offload: region_to_offload.is_syn = True else: region_to_region_map[region_to_offload.r_id].bwd_prefetch_region = region_to_offload self.region_to_region_map[region_to_offload.r_id] = region_to_region_map[region_to_offload.r_id] self._update_state(best_offl_ts) elif self.region_to_region_map.__len__() > 0: self._repair_strategy() else: raise NotImplementedError( f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, " f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!") region_to_region_map.clear() def _try_to_offload(self, host_region: Region, offload_region: Region): """ Attempts to offload the region and prefetch it in backward pass. """ # record previous information orig_prefetch = host_region.bwd_prefetch_region orig_is_syn = offload_region.is_syn orig_need_offload = offload_region.need_offload if host_region == offload_region: offload_region.is_syn = True else: host_region.bwd_prefetch_region = offload_region offload_region.need_offload = True ts, profit = self._eval_one_choice() # restore previous information host_region.bwd_prefetch_region = orig_prefetch offload_region.is_syn = orig_is_syn offload_region.need_offload = orig_need_offload return ts, profit def _try_convert_to_syn_upload(self, host_region: Region, offload_region: Region): """ Attempts to convert asynchronous prefetch into synchronous upload operations. """ # record previous information orig_prefetch = host_region.bwd_prefetch_region orig_is_syn = offload_region.is_syn assert orig_prefetch is not None and not orig_is_syn host_region.bwd_prefetch_region = None offload_region.is_syn = True ts, profit = self._eval_one_choice() # restore previous information host_region.bwd_prefetch_region = orig_prefetch offload_region.is_syn = orig_is_syn return ts, profit def _repair_strategy(self): """ Repair offload strategy. It attempts to convert asynchronous prefetch into synchronous upload operations and selects the best one. The repair process does not end until peak memory is reduced or there is no asynchronous prefetch operation. """ print("repair strategy ......") peak_mem_saving = 0 while len(self.region_to_region_map) and peak_mem_saving <= 0: max_profit = (0,) best_ts = None undo_host_region = None undo_offload_region = None for offload_region_id, host_region in self.region_to_region_map.items(): offload_region = self.region_list[offload_region_id] assert host_region.bwd_prefetch_region == offload_region assert offload_region.need_offload assert not offload_region.is_syn ts, profit = self._try_convert_to_syn_upload(host_region, offload_region) if self._compare_profit(profit, max_profit): undo_host_region = host_region undo_offload_region = offload_region max_profit = profit best_ts = ts if best_ts is None: raise NotImplementedError('repair error!') assert not undo_offload_region.is_syn undo_offload_region.is_syn = True undo_host_region.bwd_prefetch_region = None peak_mem_saving = self.best_ts.peak_mem - best_ts.peak_mem self._update_state(best_ts) self.region_to_region_map.pop(undo_offload_region.r_id) return best_ts def _eval_one_choice(self): """ Evaluate the profit of a strategy choice. Returns: AsynTrainingSimulator: the training simulator corresponding to the current strategy. tuple: contains memory saving and cost information of the current strategy. """ ts = AsynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth) ts.execute() extra_comm_cost = max(ts.iter_end_time - self.best_ts.iter_end_time, 0) profit = self._compute_offload_profit( ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost) return ts, profit class SolverFactory: solvers: Dict[str, Type[Solver]] = { 'syn': SynGreedySolver, 'asyn': AsynGreedySolver } @staticmethod def create(solver_name: str) -> Type[Solver]: if solver_name not in SolverFactory.solvers: raise TypeError(f"Unknown parameter offload policy {solver_name}") return SolverFactory.solvers[solver_name] @staticmethod def get_solver_names(): return tuple(SolverFactory.solvers.keys())