diff --git a/colossalai/auto_parallel/offload/__init__.py b/colossalai/auto_parallel/offload/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/colossalai/auto_parallel/offload/amp_optimizer.py b/colossalai/auto_parallel/offload/amp_optimizer.py new file mode 100644 index 000000000..a79e5006e --- /dev/null +++ b/colossalai/auto_parallel/offload/amp_optimizer.py @@ -0,0 +1,177 @@ +from typing import Dict, Tuple +from enum import Enum +import torch +from torch.optim import Optimizer + +from colossalai.logging import get_dist_logger +from colossalai.nn.optimizer import ColossalaiOptimizer +from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler +from colossalai.utils import get_current_device + +from .base_offload_module import BaseOffloadModule +from .region_manager import RegionManager +from .region import Region + + +class OptimState(Enum): + SCALED = 0 + UNSCALED = 1 + +class AMPOptimizer(ColossalaiOptimizer): + + """ + A wrapper for Optimizer. + Code reference: https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/optimizer/zero_optimizer.py + + Args: + optimizer (Optimizer): An Optimizer instance. + module (BaseOffloadModule): A ``BaseOffloadModule`` instance. + initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**16. + growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2. + backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5. + growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000. + hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2. + min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1. + max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32. + norm_type (float, optional): norm_type used for `clip_grad_norm`. 
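+
+    Example (illustrative sketch; it assumes the wrapped optimizer accepts the
+    ``div_scale`` keyword in ``step()``, as ``colossalai.nn.optimizer.HybridAdam``
+    does, and that ``memory_optimize`` has already produced a ``BaseOffloadModule``)::
+
+        from colossalai.nn.optimizer import HybridAdam
+
+        offload_model = memory_optimize(model, inps)        # BaseOffloadModule
+        optim = HybridAdam(offload_model.parameters(), lr=1e-3)
+        optim = AMPOptimizer(optim, offload_model)
+
+        loss = offload_model(**inps).sum()   # e.g. reduce the output to a scalar
+        optim.backward(loss)     # scales the loss and calls module.backward()
+        optim.step()             # unscales grads; skips the update on overflow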
+ """ + + def __init__(self, + optimizer: Optimizer, + module: BaseOffloadModule, + initial_scale: float = 2**16, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + min_scale: float = 1, + max_scale: float = 2**32, + clipping_norm: float = 0.0, + norm_type: float = 2.0): + + super().__init__(optimizer) + + self.module = module + self.optim_state = OptimState.UNSCALED + self.clipping_flag = clipping_norm > 0.0 + self.max_norm = clipping_norm + + self.region_manager: RegionManager = self.module.region_manager + self.param_to_range: Dict[torch.nn.Parameter, Tuple[int, int]] = dict() + self.param_to_region: Dict[torch.nn.Parameter, Region] = dict() + + self.fp32_to_fp16_params: Dict[torch.Tensor, torch.nn.Parameter] = dict() + + if self.clipping_flag: + assert norm_type == 2.0, "AMPOptimizer only supports L2 norm now" + + self.__init__optimizer() + + # Grad scaler + self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale) + self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + self._logger = get_dist_logger() + + def _set_grad_ptr(self): + for group in self.param_groups: + for fake_param in group['params']: + region = self.param_to_region[fake_param] + begin, end = self.param_to_range[fake_param] + + fake_param.data = region.cpu_grad[begin:end] + fake_param.grad = fake_param.data + fake_param.data = region.fp32_data[begin:end] + + def _update_fp16_params(self): + none_tensor = torch.empty([0]) + for group in self.param_groups: + for fake_param in group['params']: + assert fake_param.grad is None + fake_param.data = none_tensor + self.param_to_region[fake_param].cpu_grad = None + + def _check_overflow(self): + # clear previous overflow record + self._found_overflow.fill_(self.module.overflow_counter.item()) + return self._found_overflow.item() > 0 + + def _get_combined_scale(self): + loss_scale = 1 + + if self.optim_state == OptimState.SCALED: + loss_scale = self.loss_scale + self.optim_state = OptimState.UNSCALED + + combined_scale = loss_scale + + if combined_scale == 1: + return -1 + else: + return combined_scale + + @property + def loss_scale(self): + return self.grad_scaler.scale.item() + + def zero_grad(self, *args, **kwargs): + self.module.overflow_counter = torch.cuda.IntTensor([0]) + return self.optim.zero_grad(set_to_none=True) + + def step(self, *args, **kwargs): + # Copy gradients from model params to main params. + self._set_grad_ptr() + + found_inf = self._check_overflow() + if found_inf: + self.optim_state = OptimState.UNSCALED # no need to unscale grad + self.grad_scaler.update(found_inf) # update gradient scaler + self._logger.info(f'Found overflow. Skip step') + self.zero_grad() # reset all gradients + self._update_fp16_params() + return + + # get combined scale. 
combined scale = loss scale * clipping norm + # so that gradient = gradient / combined scale + combined_scale = self._get_combined_scale() + self.grad_scaler.update(found_inf) + + ret = self.optim.step(div_scale=combined_scale, *args, **kwargs) + self.zero_grad() + self._update_fp16_params() + return ret + + def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0): + raise NotImplementedError + + def backward(self, loss: torch.Tensor): + loss = self.loss_scale * loss + self.optim_state = OptimState.SCALED + self.module.backward(loss) + + def __init__optimizer(self): + + for group in self.optim.param_groups: + fake_params_list = list() + + for param in group['params']: + region = self.region_manager.get_region(param) + fake_param = torch.nn.Parameter(torch.empty([0])) + self.param_to_range[fake_param] = region.param_to_range[param] + self.param_to_region[fake_param] = region + fake_params_list.append(fake_param) + + # Reset existing state dict key to the new main param. + if param in self.optim.state: + self.optim.state[fake_param] = self.optim.state.pop(param) + + group['params'] = fake_params_list + + # Leverage state_dict() and load_state_dict() to + # recast preexisting per-param state tensors + self.optim.load_state_dict(self.optim.state_dict()) \ No newline at end of file diff --git a/colossalai/auto_parallel/offload/base_offload_module.py b/colossalai/auto_parallel/offload/base_offload_module.py new file mode 100644 index 000000000..59cea4ece --- /dev/null +++ b/colossalai/auto_parallel/offload/base_offload_module.py @@ -0,0 +1,109 @@ +from typing import Optional, Set +from functools import partial +import torch +import torch.nn as nn + +from colossalai.nn.parallel.data_parallel import _cast_float +from colossalai.gemini.tensor_utils import free_storage + +from .region_manager import RegionManager +from .util import GlobalRuntimeInfo + + +class BaseOffloadModule: + """ + BaseOffloadModule: A model wrapper for parameter offloading. + + Args: + model (nn.Module): model to apply offloading. + region_manager (RegionManager): a ``RegionManager`` instance. + is_sync (bool): synchronous mode or not. 
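+
+    Example (illustrative sketch; ``memory_optimize`` from ``mem_optimize.py`` is the
+    intended factory for this wrapper, and the input dict below is assumed to match
+    the model's forward signature)::
+
+        wrapped = memory_optimize(model, inps, memory_budget=6 * 1024 ** 3,
+                                  solver_name='asyn')
+        out = wrapped(**inps)          # inputs are cast to fp16 before the forward pass
+        wrapped.backward(out.sum())    # backward pass + gradient offload to CPU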
+ """ + + def __init__(self, + model: nn.Module, + region_manager: RegionManager, + is_sync=True): + + self.model = model + self.region_manager = region_manager + self.grad_hook_list = [] + self.overflow_counter = torch.cuda.IntTensor([0]) + + self.grad_offload_stream = torch.cuda.current_stream() if is_sync else GlobalRuntimeInfo.d2h_stream + + self._cast_buffers() + + def register_grad_hook(self): + for p in self.model.parameters(): + if p.requires_grad: + self.grad_hook_list.append(p.register_hook(partial(self.grad_handle, p))) + + def remove_grad_hook(self): + for hook in self.grad_hook_list: + hook.remove() + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def _pre_forward(self): + self.register_grad_hook() + for region in self.region_manager.region_list: + region.cpu_grad = None + + def forward(self, *args, **kwargs): + args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half) + self.model.zero_grad(set_to_none=True) + self._pre_forward() + outputs = self.model(*args, **kwargs) + return outputs + + def backward(self, loss): + loss.backward() + self._post_backward() + + def _post_backward(self): + torch.cuda.synchronize() + self.remove_grad_hook() + + for p in self.model.parameters(): + p.grad = None + + GlobalRuntimeInfo.fwd_prefetch_event_map.clear() + GlobalRuntimeInfo.bwd_prefetch_event_map.clear() + + def grad_handle(self, p, grad): + empty_grad = torch.empty_like(grad) + free_storage(empty_grad) + with torch._C.DisableTorchFunction(): + region = self.region_manager.get_region(p) + region.copy_grad_to_region_slice(p, grad) + if region.can_release: + self.overflow_counter += region.has_inf_or_nan + master_stream = torch.cuda.current_stream() + with torch.cuda.stream(self.grad_offload_stream): + GlobalRuntimeInfo.d2h_stream.wait_stream(master_stream) + region.move_grad_to_cpu() + return empty_grad + + def _cast_buffers(self): + for buffer in self.model.buffers(): + buffer.data = buffer.cuda() + + def parameters(self, recurse: bool = True): + return self.model.parameters(recurse) + + def named_parameters(self, prefix: str = '', recurse: bool = True): + return self.model.named_parameters(prefix, recurse) + + def named_buffers(self, prefix: str = '', recurse: bool = True): + return self.model.named_buffers(prefix, recurse) + + def named_children(self): + return self.model.named_children() + + def named_modules(self, + memo: Optional[Set[torch.nn.Module]] = None, + prefix: str = '', + remove_duplicate: bool = True): + return self.model.named_modules(memo, prefix, remove_duplicate) diff --git a/colossalai/auto_parallel/offload/mem_optimize.py b/colossalai/auto_parallel/offload/mem_optimize.py new file mode 100644 index 000000000..02778696a --- /dev/null +++ b/colossalai/auto_parallel/offload/mem_optimize.py @@ -0,0 +1,49 @@ +from typing import Dict +import torch +import torch.fx +from torch.fx import GraphModule +from torch.utils._pytree import tree_map + +from colossalai.fx import ColoTracer, is_compatible_with_meta +from colossalai.fx.passes.meta_info_prop import MetaInfoProp + +from .region_manager import RegionManager +from .runtime import runtime_syn_offload_apply_pass, runtime_asyn_offload_apply_pass +from .base_offload_module import BaseOffloadModule +from .util import compute_max_param_mem, compute_total_param_mem, compute_act_peak_mem, GlobalRuntimeInfo + +def memory_optimize(model: torch.nn.Module, + inps: Dict[str, torch.Tensor], + memory_budget: float = -1.0, + solver_name: str = 'asyn'): + + model = model.cpu().half() + 
tracer = ColoTracer() + assert is_compatible_with_meta() + wrap_fn = lambda x: x.to("meta") if isinstance(x, torch.Tensor) else x + meta_args = tree_map(wrap_fn, inps) + graph = tracer.trace(model, meta_args=meta_args) + gm = GraphModule(model, graph, model.__class__.__name__) + interp = MetaInfoProp(gm) + interp.propagate(*meta_args.values()) + + region_manager = RegionManager(graph, solver_name=solver_name, memory_budget=memory_budget) + region_manager._build_regions() + GlobalRuntimeInfo.region_list = region_manager.region_list + + act_peak_mem = compute_act_peak_mem(region_manager.region_list) / 1024 ** 2 + max_param_mem = compute_max_param_mem(region_manager.region_list) / 1024 ** 2 + total_param_mem = compute_total_param_mem(region_manager.region_list) / 1024 ** 2 + print( + f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}") + + if solver_name == 'syn': + gm = runtime_syn_offload_apply_pass(gm, region_manager.region_list) + elif solver_name == 'asyn': + gm = runtime_asyn_offload_apply_pass(gm, region_manager.region_list) + else: + raise TypeError(f"Unknown solver name {solver_name}!") + + gm.recompile() + optimized_model = BaseOffloadModule(gm, region_manager, solver_name=='syn') + return optimized_model diff --git a/colossalai/auto_parallel/offload/region.py b/colossalai/auto_parallel/offload/region.py new file mode 100644 index 000000000..e6907cc4b --- /dev/null +++ b/colossalai/auto_parallel/offload/region.py @@ -0,0 +1,144 @@ +from typing import List, Dict, Tuple +import torch +from torch.fx import Node +from colossalai.gemini.tensor_utils import alloc_storage, free_storage + +class Region: + """ + Region: A container owning a piece of contiguous nodes in the DNN computing graph. + + Args: + r_id (int): the index of the region in the computing graph. + """ + + def __init__(self, r_id: int = 0) -> None: + self.r_id: int = r_id + self.fp16_params: List[torch.nn.Parameter] = [] + self.param_size: int = 0 + self.shared_rid: int = self.r_id + + self.param_num: int = 0 + self.grad_num: int = 0 + self.fp16_data = None + self.fp32_data = None + self.cpu_grad = None + self.temp_fp32_data = None + self.param_to_range: Dict[torch.nn.Parameter, Tuple[int, int]] = dict() + + self.need_offload: bool = False + self.is_syn: bool = False + self.nodes: List[Node] = [] + self.fwd_prefetch_region = None + self.bwd_prefetch_region = None + + self.in_mem_pool_flag: bool = False + + @property + def can_release(self) -> bool: + """ + Check if the region can be released. + """ + return self.grad_num == self.param_num + + @property + def has_inf_or_nan(self) -> bool: + """ + Check if the grad of the region has inf or nan values on CUDA. + """ + return torch.isinf(self.fp16_data).any() | torch.isnan(self.fp16_data).any() + + def init_param_data(self, pre_alloc_tensor: torch.Tensor = None): + """ + Map the parameters in the region to a contiguous memory space. 
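+
+        After this call, every parameter's ``.data`` becomes a view into the flat
+        ``fp16_data`` buffer, a pinned fp32 master copy is kept in ``fp32_data`` on
+        the CPU, and the CUDA fp16 storage is freed until ``move_param_to_cuda``
+        re-allocates (or reuses) it.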
+ """ + + self.fp16_data = torch.zeros( + self.param_num, dtype=torch.half, device='cuda') + offset = 0 + for param in self.fp16_params: + param.data = param.data.cuda() + p_num = param.data.numel() + self.fp16_data[offset:offset + p_num].copy_(param.data.flatten()) + param.data = self.fp16_data[offset:offset + + p_num].view(param.data.shape) + self.param_to_range[param] = (offset, offset + p_num) + offset += p_num + + self.fp32_data = self.fp16_data.float().cpu().pin_memory() + free_storage(self.fp16_data) + if self.in_mem_pool_flag and pre_alloc_tensor is not None: + self.fp16_data = pre_alloc_tensor + + def move_param_to_cuda(self): + """ + Move parameters from CPU to GPU. + It first moves float32 parameters to GPU and + then transforms float32 parameters to half-precision on the GPU. + The reason is that the performance of precision conversion on the CPU + is much slower than the data transfer overhead. + """ + + self.temp_fp32_data.copy_(self.fp32_data, non_blocking=True) + self.temp_fp32_data.record_stream(torch.cuda.current_stream()) + if not self.in_mem_pool_flag: + alloc_storage(self.fp16_data) + self.fp16_data[:self.param_num].copy_(self.temp_fp32_data) + self.fp16_data.record_stream(torch.cuda.current_stream()) + + self.__update_params_ptr() + + def move_grad_to_cpu(self): + """ + Move gradients from GPU to CPU. + """ + + self.cpu_grad = torch.empty(self.param_num, dtype=torch.half, pin_memory=True) + self.cpu_grad.copy_(self.fp16_data[:self.param_num], non_blocking=True) + self.fp16_data.record_stream(torch.cuda.current_stream()) + if not self.in_mem_pool_flag: + self.free_cuda_data() + + self.grad_num = 0 + + def free_cuda_data(self): + free_storage(self.fp16_data) + + # torch.cuda.empty_cache() + + def copy_grad_to_region_slice(self, param: torch.nn.Parameter, data_slice: torch.Tensor) -> None: + """ + Copy data slice to the memory space indexed by the input tensor in the region. + + Args: + param (torch.nn.Parameter): the param used to retrive meta information + data_slice (torch.Tensor): the tensor to be copied to the region + """ + + begin, end = self.param_to_range[param] + self.fp16_data[begin:end].copy_(data_slice.data.flatten()) + param.data = self.fp16_data[begin:end].view(param.data.shape) + + self.grad_num += data_slice.numel() + + def split(self, cut_node_idx: int, cut_param_idx: int): + """ + Split the region into two and return the latter. 
+ """ + new_reg = Region(r_id=self.r_id + 1) + new_reg.nodes = self.nodes[cut_node_idx:] + new_reg.fp16_params = self.fp16_params[cut_param_idx:] + for p in new_reg.fp16_params: + new_reg.param_size += p.data.numel() * p.data.element_size() + new_reg.param_num += p.data.numel() + + self.nodes = self.nodes[:cut_node_idx] + self.fp16_params = self.fp16_params[:cut_param_idx] + self.param_size -= new_reg.param_size + self.param_num -= new_reg.param_num + + return new_reg + + def __update_params_ptr(self) -> None: + for param in self.fp16_params: + begin, end = self.param_to_range[param] + param.data = self.fp16_data[begin:end].view(param.data.shape) \ No newline at end of file diff --git a/colossalai/auto_parallel/offload/region_manager.py b/colossalai/auto_parallel/offload/region_manager.py new file mode 100644 index 000000000..30bfaf00d --- /dev/null +++ b/colossalai/auto_parallel/offload/region_manager.py @@ -0,0 +1,526 @@ +from typing import List, Any, Dict, Tuple +import torch +from torch.fx import Graph, Node + +from .solver import SolverFactory +from .training_simulator import TrainingSimulator +from .region import Region +from .util import NodeInfo + + +class RegionManager: + """ + RegionManager is used to construct and manage the offload plan for the model execution. + + Args: + graph (Graph): a Graph object used for analysis and strategy generation. + solver_name (str): a solver name which specifies the preferences for plan searching. + memory_budget (float): the given memory budget. + cnode (List[str], optional): Common node List, should be the subset of input. + """ + + def __init__(self, + graph: Graph, + solver_name: str = 'asyn', + memory_budget: float = -1.0, + cnode: List[str] = None): + + self.graph = graph + assert graph.owning_module is not None, 'The given graph is not associated with a owning_module' + self.root_module = self.graph.owning_module + self.nodes = list(graph.nodes) + self.cnode = cnode + self.only_param_ops = [] + self.param_region_map: Dict[torch.nn.Parameter, Region] = dict() + self.shared_region_pairs: List[Tuple[Region, Region]] = list() + self.region_list: List[Region] = list() + self.rid_in_pool: List[int] = list() + self.mem_block_size: int = 0 + self.memory_budget = memory_budget + + self.solver_name = solver_name + self.require_pool: bool = solver_name == 'asyn' + + self.reg_to_block: Dict[int, int] = dict() + + def _build_regions(self): + """ + 1. Pre-processing, mainly contains linearized computing graph and + merge smaller regions into larger ones. + 2. Construct a solver to search for an efficient offload strategy. + 3. Post-processing, mainly contains early region placement if using asynchronous mode, + and initialize region data. 
+ """ + + self._pre_process() + + solver_cls = SolverFactory.create(self.solver_name) + solver = solver_cls(self.region_list, self.memory_budget) + solver._call_solver() + + self._post_process(solver.best_ts) + + def _pre_process(self): + + init_region_list = self._linearize_graph() + + if len(self.shared_region_pairs) > 1: + raise NotImplementedError( + 'The current version only considers at most one pair of parameter sharing.') + + elif len(self.shared_region_pairs) == 1: + shared_regs = self.shared_region_pairs[0] + assert shared_regs[0].shared_rid == shared_regs[1].r_id \ + and shared_regs[1].shared_rid == shared_regs[0].r_id + fst_id = shared_regs[0].r_id + lst_id = shared_regs[1].r_id + regs_left_out = init_region_list[:fst_id + 1] + regs_right_out = init_region_list[lst_id:] + hold_regs = init_region_list[fst_id + 1:lst_id] + else: + regs_left_out = [] + regs_right_out = [] + hold_regs = init_region_list + + self.mem_block_size = self._search_block_size(hold_regs) + hold_regs = self._merge_small_regions(hold_regs) + + if self.require_pool: + for reg in hold_regs: + reg.in_mem_pool_flag = True + self.rid_in_pool.append(reg.r_id) + + self.region_list.extend(regs_left_out) + self.region_list.extend(hold_regs) + + for reg in regs_right_out: + reg.r_id = self.region_list[-1].r_id + 1 + self.region_list[reg.shared_rid].shared_rid = reg.r_id + self.region_list.append(reg) + + self._process_shared_region() + + self.max_param_num = max([reg.param_num for reg in self.region_list]) + self.memory_budget -= self.max_param_num * torch.tensor([], dtype=torch.float32).element_size() + + def _post_process(self, ts: TrainingSimulator = None): + if self.require_pool: + self._early_region_placement(ts) + self._init_region_data() + + def _early_region_placement(self, ts: TrainingSimulator): + """ + Implemented the early region placement strategy to avoid GPU memory fragmentation. + It maps all region data into a contiguous memory space and + reuses the same memory space for regions that do not coexist. + + Args: + ts (TrainingSimulator): the best training simulator, which records region execution flow. + + Raises: + NotImplementedError: due to the naive implementation, + it may not find a suitable region placement strategy for the given execution flow. + """ + + reg_flow = torch.cat( + [ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0) + mem_block_num = torch.max( + torch.sum(reg_flow[:, self.rid_in_pool], dim=1)) + coexist_matrix = torch.logical_or( + ts.fwd_reg_flow, ts.bwd_reg_flow) + + block_to_regs = {} + for block_idx in range(mem_block_num): + block_to_regs[block_idx] = [] + for reg in self.region_list: + if reg.r_id in self.rid_in_pool: + cur_reg_appears = coexist_matrix[:, reg.r_id] + cur_reg_coexists = torch.sum( + coexist_matrix[cur_reg_appears], dim=0).bool() + for block_idx in range(mem_block_num): + if not any(cur_reg_coexists[block_to_regs[block_idx]]): + block_to_regs[block_idx].append(reg.r_id) + self.reg_to_block[reg.r_id] = block_idx + break + + if reg.r_id not in self.reg_to_block: + raise NotImplementedError( + f'can not find a block from the memory pool to store parameters of the region') + self.memory_pool = torch.chunk(torch.zeros(int( + mem_block_num * self.mem_block_size / 2), dtype=torch.half, device='cuda'), chunks=int(mem_block_num)) + + def _merge_small_regions(self, orig_reg_list: List[Region]) -> List[Region]: + """ + Merge smaller regions into larger ones for better bandwidth utilization and easier management. + It is inspired by Gemini. 
+ + Args: + orig_reg_list (List[Region]): original region list. + + Returns: + List[Region]: region list after merging. + """ + + r_id = orig_reg_list[0].r_id + region = Region(r_id=r_id) + region_list = [region] + + for orig_reg in orig_reg_list: + if region_list[-1].param_size + orig_reg.param_size > self.mem_block_size: + r_id += 1 + region = Region(r_id=r_id) + region_list.append(region) + region.param_size += orig_reg.param_size + region.param_num += orig_reg.param_num + region.nodes.extend(orig_reg.nodes) + region.fp16_params.extend(orig_reg.fp16_params) + self.__update_param_region_map(orig_reg.fp16_params, region) + + return region_list + + def _search_block_size(self, + region_list: List[Region], + search_interval_byte: int = 1024, + search_range_byte: int = 128 * 1024 ** 2) -> int: + """ + Search for a suitable memory block size. + + Args: + region_list (List[Region]): region list. + search_interval_byte (int): searching interval in byte. + search_range_byte (int): searching range in byte. + + Returns: + int: the best memory block size. + """ + + def _get_wasted_mem(size_list: List[int], blk_size: int): + """ + Get wasted byte for a certain block size. + """ + acc_wasted = 0 + left = 0 + for s in size_list: + if left + s > blk_size: + acc_wasted += blk_size - left + left = s + left += s + acc_wasted += blk_size - left + return acc_wasted + + param_size_list = [ + region.param_size for region in region_list if region.r_id == region.shared_rid] + + start_size = max(param_size_list) + min_mem_waste = float('+inf') + best_block_size = start_size + + for block_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte): + temp_waste = 0 + temp_waste += _get_wasted_mem(param_size_list, block_size) + if temp_waste < min_mem_waste: + min_mem_waste = temp_waste + best_block_size = block_size + + return best_block_size + + def _init_region_data(self): + """ + Initialize region data, which maps the parameters in the region to a contiguous memory space. + """ + + self.temp_fp32_data = torch.zeros(self.max_param_num, device='cuda', dtype=torch.float32) + + for region in self.region_list: + pre_alloc_tensor = None + if self.require_pool and region.r_id in self.rid_in_pool: + block_idx = self.reg_to_block[region.r_id] + pre_alloc_tensor = self.memory_pool[block_idx] + + if region.r_id <= region.shared_rid: + region.init_param_data(pre_alloc_tensor) + else: + shared_region = self.region_list[region.shared_rid] + region.fp16_data = shared_region.fp16_data + region.fp32_data = shared_region.fp32_data + region.param_to_range = shared_region.param_to_range + region.temp_fp32_data = self.temp_fp32_data[:region.param_num].detach( + ) + + torch.cuda.empty_cache() + + def _process_shared_region(self): + """ + Special processing for the shared region, which uses GPT2 and Bert case as a priori knowledge. 
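+        Concretely, it assumes the shared parameter is a tied word-embedding /
+        output-projection weight: the earlier region ends with an ``nn.Embedding``
+        module, and the later region is split right after the ``nn.Linear`` /
+        ``F.linear`` call that reuses that weight, so the two halves can be linked
+        as shared regions.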
+ """ + + if len(self.shared_region_pairs): + assert len(self.shared_region_pairs) <= 1 + former_reg, latter_reg = self.shared_region_pairs[0] + assert latter_reg.param_num >= former_reg.param_num + embedding_node = former_reg.nodes[-1] + assert embedding_node.op == 'call_module' and isinstance( + self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding) + if latter_reg.param_num > former_reg.param_num: + for idx, n in enumerate(latter_reg.nodes): + if (n.op == 'call_module' and isinstance(self.root_module.get_submodule(n.target), + torch.nn.Linear)) or \ + (n.op == 'call_function' and n.target is torch.nn.functional.linear): + cut_node_idx = idx + 1 + break + assert len(latter_reg.fp16_params) == 2 + new_reg = latter_reg.split(cut_node_idx, 1) + for p in new_reg.fp16_params: + self.param_region_map[p] = new_reg + self.region_list.insert(new_reg.r_id, new_reg) + for reg in self.region_list[new_reg.r_id + 1:]: + reg.r_id += 1 + latter_reg.shared_rid = former_reg.r_id + former_reg.shared_rid = latter_reg.r_id + + def _linearize_graph(self) -> List[Region]: + """Linearizing the graph + + Args: + graph (Graph): The computing graph to be optimized. + + Returns: + List[Region]: each region contains the actual 'node' in linearized manner. + + Remarks: + Do merge the inplace ops and shape-consistency ops into the previous node. + """ + + # List of target name that could be seen as common node + common_ops = ["getattr", "getitem", "size"] + + def _is_cop(target: Any) -> bool: + """Check if an op could be seen as common node + + Args: + target (Any): node target + + Returns: + bool + """ + + if isinstance(target, str): + return target in common_ops + else: + return target.__name__ in common_ops + + def _is_act(data: Any) -> bool: + """Check if an op could be seen as parameter computation start + + Args: + data (Any): meta_data + + Returns: + bool + """ + + label = False + if isinstance(data, torch.Tensor): + return True + elif isinstance(data, (tuple, list)): + for d in data: + label = label or _is_act(d) + return label + + def _maybe_param_comp_start() -> bool: + """Check if an op could be seen as parameter computation start + + Args: + n (Node): node + + Returns: + bool + """ + + label = False + if n.op == "get_attr": + label = True + elif n.op == "call_module": + target = n.target + submod = self.root_module.get_submodule(target) + if ( + len(list(submod.named_parameters(recurse=False))) != 0 + or len(list(submod.named_buffers(recurse=False))) != 0 + ): + label = True + + return label and not sum([v for _, v in param_op_deps.items()]) + + def _is_param_comp_end() -> bool: + """Check if an op could be seen as parameter computation end + + Args: + n (Node): node + + Returns: + bool + """ + + def _is_inplace(n: Node): + """Get the inplace argument from ``torch.fx.Node`` + """ + inplace = False + if n.op == "call_function": + inplace = n.kwargs.get("inplace", False) + elif n.op == "call_module": + inplace = getattr(n.graph.owning_module.get_submodule( + n.target), "inplace", False) + return inplace + + label = False + + if n.op == "call_module": + target = n.target + submod = self.root_module.get_submodule(target) + if ( + len(list(submod.named_parameters(recurse=False))) != 0 + or len(list(submod.named_buffers(recurse=False))) != 0 + ): + label = True + + elif n.op == "call_function": + label = any(map(lambda x: x.name in self.only_param_ops, n.all_input_nodes)) and any( + map(lambda x: x.name not in self.only_param_ops and not _is_cop(n.target), n.all_input_nodes)) + + return label 
and not sum([v for _, v in param_op_deps.items()]) and not any(map(_is_inplace, n.users)) + + def _exception_node_handling(): + # TODO meta info prop bug + if n.name.__contains__("transpose") and n.meta['fwd_out'][0].dim() <= 2: + n.meta['fwd_out'] = [] + + # make sure that item in cnode is valid + if self.cnode: + for name in self.cnode: + try: + assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \ + f"Common node {name} is not an input of the model." + except StopIteration: + raise ValueError(f"Common node name {name} not in graph.") + else: + self.cnode = [] + + node_id = 0 + region_id = 0 + + param_op_deps = {} + + deps = {} + region_list = [] + region = Region(r_id=region_id) + + act_n = None + + for n in self.graph.nodes: + if n.op != "placeholder" and n.op != "output": + for n_par in n.all_input_nodes: + if n_par.op != "placeholder" and n_par.name not in self.cnode: + deps[n_par] -= 1 + if n_par.op != "placeholder" and n_par.name in self.only_param_ops: + param_op_deps[n_par] -= 1 + + if act_n in region.nodes and _maybe_param_comp_start(): + ns = [] + border_n_idx = region.nodes.index(act_n) + if border_n_idx < len(region.nodes): + ns = region.nodes[border_n_idx + 1:] + region.nodes = region.nodes[:border_n_idx + 1] + region_list.append(region) + region_id += 1 + region = Region(r_id=region_id) + region.nodes = ns + + _exception_node_handling() + region.nodes.append(n) + self._set_node_and_region_info(node_id, n, region) + node_id += 1 + + # if the node could free all dependencies in graph + # we could begin a new region + if _is_param_comp_end(): + region_list.append(region) + region_id += 1 + region = Region(r_id=region_id) + + # propagate common node attr if possible + if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode + ]) or _is_cop(n.target): + self.cnode.append(n.name) + else: + deps[n] = len( + [user for user in n.users if user.op != "output"]) + + # propagate param node attr if possible + if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.only_param_ops + ]) or n.op == "get_attr": + self.only_param_ops.append(n.name) + param_op_deps[n] = len( + [user for user in n.users if user.op != "output"]) + + # record last activation node + if _is_act(n._meta_data): + act_n = n + + if len(region.nodes): + region_list.append(region) + + return region_list + + def _set_node_and_region_info(self, node_id: int, cur_n: Node, cur_reg: Region): + + cur_n.node_info = NodeInfo(node_id) + + if cur_n.op == 'call_module': + target = cur_n.target + submod = self.root_module.get_submodule(target) + for p in list(submod.parameters(recurse=False)): + + if p in self.param_region_map: + cur_reg.shared_rid = self.param_region_map[p].r_id + self.param_region_map[p].shared_rid = cur_reg.r_id + self.shared_region_pairs.append( + (self.param_region_map[p], cur_reg)) + else: + self.param_region_map[p] = cur_reg + + cur_reg.fp16_params.append(p) + cur_reg.param_num += p.data.numel() + cur_reg.param_size += p.data.numel() * p.data.element_size() + + elif cur_n.op == "get_attr": + attr_itr = self.root_module + atoms = cur_n.target.split(".") + for atom in atoms: + attr_itr = getattr(attr_itr, atom) + + if isinstance(attr_itr, torch.nn.Parameter): + + if attr_itr in self.param_region_map: + cur_reg.shared_rid = self.param_region_map[attr_itr].r_id + self.param_region_map[attr_itr].shared_rid = cur_reg.r_id + self.shared_region_pairs.append( + (self.param_region_map[attr_itr], cur_reg)) + 
else: + self.param_region_map[attr_itr] = cur_reg + + cur_reg.fp16_params.append(attr_itr) + cur_reg.param_num += attr_itr.data.numel() + cur_reg.param_size += attr_itr.data.numel() * attr_itr.data.element_size() + + def get_region(self, param: torch.nn.Parameter) -> Region: + """ + Return the region owning the parameter. + + Args: + param (torch.nn.Parameter): a torch parameter object + """ + return self.param_region_map[param] + + def __update_param_region_map(self, params: List[torch.nn.Parameter], region: Region): + for p in params: + self.param_region_map[p] = region diff --git a/colossalai/auto_parallel/offload/runtime.py b/colossalai/auto_parallel/offload/runtime.py new file mode 100644 index 000000000..91c7945bd --- /dev/null +++ b/colossalai/auto_parallel/offload/runtime.py @@ -0,0 +1,253 @@ +from typing import List +import torch +from torch.fx.node import Node + +from .region import Region +from .util import GlobalRuntimeInfo, requires_upload_p_in_fwd + + +class SynPreFwdPostBwdOP(torch.autograd.Function): + """ + A customized prefetch and offload operation. + + Args: + input_: input tensor. + fwd_info: information dict, which contains region indices + that need to be uploaded or freed during forward pass. + bwd_info: information dict, which contains region indices + that need to be uploaded during backward pass. + """ + + @staticmethod + def forward(ctx, input_, fwd_info, bwd_info): + ctx.bwd_info = bwd_info + d2h_rid = fwd_info.get('d2h_rid', None) + if d2h_rid is not None: + free_region = GlobalRuntimeInfo.region_list[d2h_rid] + assert isinstance(free_region, Region) + free_region.free_cuda_data() + + h2d_rid = fwd_info.get('h2d_rid', None) + if h2d_rid is not None: + h2d_region = GlobalRuntimeInfo.region_list[h2d_rid] + assert isinstance(h2d_region, Region) + h2d_region.move_param_to_cuda() + + return input_ + + @staticmethod + def backward(ctx, grad_output): + + h2d_rid = ctx.bwd_info.get('h2d_rid', None) + if h2d_rid is not None: + pref_region = GlobalRuntimeInfo.region_list[h2d_rid] + assert isinstance(pref_region, Region) + pref_region.move_param_to_cuda() + + return grad_output, None, None + + +class AsynPreFwdPostBwdOP(torch.autograd.Function): + """ + A customized prefetch and offload operation. + + Args: + input_: input tensor. + fwd_info: information dict, which contains region indices + that need to be prefetched, waited, or freed during forward pass. + bwd_info: information dict, which contains region indices + that need to be prefetched or waited during backward pass. 
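+
+    Example of the info dicts built by ``runtime_asyn_offload_apply_pass``
+    (region ids are illustrative; the pass may also attach a ``d2h_rid`` entry
+    to ``fwd_info`` for regions whose CUDA storage can be freed)::
+
+        fwd_info = {'sync_rid': 3, 'h2d_rid': 4}    # wait for region 3, prefetch region 4
+        bwd_info = {'sync_rid': 5, 'h2d_rid': 2}    # wait for region 5, prefetch region 2
+        out = convert_fwd_prefetch_bwd_offload_to_action(hidden_states, fwd_info, bwd_info)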
+ """ + + @staticmethod + def forward(ctx, input_, fwd_info, bwd_info): + ctx.bwd_info = bwd_info + + sync_rid = fwd_info.get('sync_rid', None) + if sync_rid is not None: + prefetch_event = GlobalRuntimeInfo.fwd_prefetch_event_map.get( + sync_rid, None) + if prefetch_event: + prefetch_event.wait() + + h2d_rid = fwd_info.get('h2d_rid', None) + if h2d_rid is not None: + pref_region = GlobalRuntimeInfo.region_list[h2d_rid] + assert isinstance(pref_region, Region) + master_stream = torch.cuda.current_stream() + with torch.cuda.stream(GlobalRuntimeInfo.h2d_stream): + GlobalRuntimeInfo.h2d_stream.wait_stream(master_stream) + pref_region.move_param_to_cuda() + + prefetch_event = torch.cuda.Event() + prefetch_event.record(GlobalRuntimeInfo.h2d_stream) + GlobalRuntimeInfo.fwd_prefetch_event_map[h2d_rid] = prefetch_event + + return input_ + + @staticmethod + def backward(ctx, grad_output): + + sync_rid = ctx.bwd_info.get('sync_rid', None) + if sync_rid is not None: + wait_region = GlobalRuntimeInfo.region_list[sync_rid] + assert isinstance(wait_region, Region) + prefetch_event = GlobalRuntimeInfo.bwd_prefetch_event_map.get( + sync_rid, None) + if prefetch_event: + prefetch_event.wait() + else: + wait_region.move_param_to_cuda() + + h2d_rid = ctx.bwd_info.get('h2d_rid', None) + if h2d_rid is not None: + pref_region = GlobalRuntimeInfo.region_list[h2d_rid] + assert isinstance(pref_region, Region) + master_stream = torch.cuda.current_stream() + with torch.cuda.stream(GlobalRuntimeInfo.h2d_stream): + GlobalRuntimeInfo.h2d_stream.wait_stream(master_stream) + pref_region.move_param_to_cuda() + + prefetch_event = torch.cuda.Event() + prefetch_event.record(GlobalRuntimeInfo.h2d_stream) + GlobalRuntimeInfo.bwd_prefetch_event_map[h2d_rid] = prefetch_event + return grad_output, None, None + + +def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info): + ''' + Convert Upload and Offload operation into runtime action. + + Argument: + tensor(torch.Tensor): input tensor. + fwd_info(dict): information dict, which contains region indices + that need to be uploaded, or freed during forward pass. + bwd_info(dict): information dict, which contains region indices + that need to be uploaded during backward pass. + ''' + with torch._C.DisableTorchFunction(): + ret = SynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info) + return ret + +def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info): + ''' + Convert Prefetch and Offload operation into runtime action. + + Argument: + tensor(torch.Tensor): input tensor. + fwd_info(dict): information dict, which contains region indices + that need to be prefetched, waited, or freed during forward pass. + bwd_info(dict): information dict, which contains region indices + that need to be prefetched or waited during backward pass. 
+ ''' + with torch._C.DisableTorchFunction(): + ret = AsynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info) + return ret + + +def replace_node_users(orig_node: Node, inserted_node: Node, rep_user_nodes: List[Node] = None): + user_list = list(orig_node.users.keys()) + if rep_user_nodes is not None: + user_list = rep_user_nodes + for user in user_list: + if user == inserted_node: + continue + new_args = list(user.args) + new_kwargs = dict(user.kwargs) + # the origin node may be a positional argument or key word argument of user node + if orig_node in new_args: + # substitute the origin node with offload_apply_node + new_args[new_args.index(orig_node)] = inserted_node + user.args = tuple(new_args) + elif str(orig_node) in new_kwargs: + # substitute the origin node with offload_apply_node + new_kwargs[str(orig_node)] = inserted_node + user.kwargs = new_kwargs + + +def runtime_syn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[Region]): + """ + This pass is used to add the synchronous upload and offload spec apply node to the origin graph. + """ + mod_graph = gm.graph + last_inp_node = tuple(mod_graph.nodes)[0] + + for r_idx, region in enumerate(region_list): + # forward upload + fwd_info = {} + if requires_upload_p_in_fwd(region_list[region.shared_rid]): + fwd_info['h2d_rid'] = region.r_id + + # forward offload + if r_idx > 0 and region_list[r_idx - 1].need_offload: + fwd_info['d2h_rid'] = r_idx - 1 + + bwd_info = {} + # backward upload + if r_idx > 0 and region_list[r_idx - 1].need_offload: + bwd_info['h2d_rid'] = region_list[r_idx - 1].r_id + + if fwd_info or bwd_info: + with mod_graph.inserting_after(last_inp_node): + new_node = mod_graph.create_node('call_function', convert_fwd_upload_bwd_offload_to_action, + args=(last_inp_node, fwd_info, bwd_info)) + replace_node_users(last_inp_node, new_node) + + last_inp_node = region.nodes[-1] + + return gm + + +def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[Region]): + """ + This pass is used to add the asynchronous prefetch and offload spec apply node to the origin graph. 
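+
+    Conceptually (illustrative), for the last node ``last_n`` of each region the pass
+    rewrites every edge ``last_n -> user`` into
+    ``last_n -> convert_fwd_prefetch_bwd_offload_to_action(last_n, fwd_info, bwd_info) -> user``,
+    so prefetch/offload actions are triggered exactly at region boundaries in both the
+    forward and backward passes.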
+ """ + mod_graph = gm.graph + + # upload parameters of the first region + last_inp_node = tuple(mod_graph.nodes)[0] + first_region_with_p = [ + region for region in region_list if region.param_size][0] + fwd_info = {"h2d_rid": first_region_with_p.r_id} + with mod_graph.inserting_after(last_inp_node): + upload_apply_node = mod_graph.create_node('call_function', convert_fwd_upload_bwd_offload_to_action, + args=(last_inp_node, fwd_info, {})) + replace_node_users(last_inp_node, upload_apply_node) + last_inp_node = upload_apply_node + + for r_idx, region in enumerate(region_list): + # forward prefetch + fwd_info = {} + if region.param_size: + fwd_info['sync_rid'] = region.r_id + fwd_prefetch_region = region.fwd_prefetch_region + if fwd_prefetch_region and requires_upload_p_in_fwd(region_list[fwd_prefetch_region.shared_rid]): + fwd_info['h2d_rid'] = fwd_prefetch_region.r_id + + # forward offload + if r_idx > 0 and region_list[r_idx-1].need_offload: + fwd_info['d2h_rid'] = r_idx - 1 + + bwd_info = {} + # backward prefetch + if r_idx > 0 and region_list[r_idx-1].need_offload: + bwd_info['sync_rid'] = r_idx - 1 + if r_idx > 0 and region_list[r_idx-1].bwd_prefetch_region: + bwd_info['h2d_rid'] = region_list[r_idx-1].bwd_prefetch_region.r_id + + if fwd_info or bwd_info: + with mod_graph.inserting_after(last_inp_node): + new_node = mod_graph.create_node('call_function', convert_fwd_prefetch_bwd_offload_to_action, + args=(last_inp_node, fwd_info, bwd_info)) + replace_node_users(last_inp_node, new_node) + + last_inp_node = region.nodes[-1] + + if region.bwd_prefetch_region: + bwd_info = {'h2d_rid': region.bwd_prefetch_region.r_id} + with mod_graph.inserting_after(last_inp_node): + new_node = mod_graph.create_node('call_function', convert_fwd_prefetch_bwd_offload_to_action, + args=(last_inp_node, {}, bwd_info)) + replace_node_users(last_inp_node, new_node) + # gm.graph.print_tabular() + return gm diff --git a/colossalai/auto_parallel/offload/solver.py b/colossalai/auto_parallel/offload/solver.py new file mode 100644 index 000000000..161f7ff86 --- /dev/null +++ b/colossalai/auto_parallel/offload/solver.py @@ -0,0 +1,523 @@ +import time +from typing import List, Dict, Type +from abc import ABC, abstractmethod + +NOT_NVML = False +try: + from pynvml import * +except: + NOT_NVML = True + +import torch +from torch.fx.node import Node +from colossalai.utils.cuda import get_current_device + +from .training_simulator import TrainingSimulator, SynTrainingSimulator, AsynTrainingSimulator +from .region import Region +from .util import NodeInfo, NvDevicePower + + +def benchmark_func(func, number=1, repeat=1, warmup=3): + """ + benchmark data transfer cost. + """ + + for i in range(warmup): + func() + + costs = [] + + for i in range(repeat): + torch.cuda.synchronize() + begin = time.time() + for i in range(number): + func() + torch.cuda.synchronize() + costs.append((time.time() - begin) / number) + + return sum(costs) / len(costs) + + +class Solver(ABC): + """ + The parameter offload solver. + + Args: + region_list (List[Region]): represents the linearized DNN computing graph. + memory_budget (float): the given memory budget. + error_factor (float): the error factor. + It is used to reduce the memory budget. Due to some errors in the estimation of peak memory and execution time. 
+ """ + + def __init__(self, + region_list: List[Region], + memory_budget: float = -1.0, + error_factor: float = 0.95) -> None: + + self.region_list = region_list + + self.error_factor: float = error_factor + if memory_budget > 0: + self.memory_budget = memory_budget * self.error_factor + else: + self.memory_budget = torch.cuda.get_device_properties( + get_current_device()).total_memory * self.error_factor + + self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth() + self.comp_power: float = self._extract_computing_power() + + @abstractmethod + def _call_solver(self): + raise NotImplementedError + + @abstractmethod + def _try_to_offload(self, *args): + raise NotImplementedError + + @abstractmethod + def _eval_one_choice(self, *args): + raise NotImplementedError + + def _compute_offload_profit(self, total_mem_saving: float, peak_mem_saving: float, extra_cost: float): + """ + Compute the profits of the offload strategies, + which packages the memory savings information for subsequent comparisons. + + Args: + total_mem_saving (float): the total memory saving of the offload strategy. + peak_mem_saving (float): the peak memory saving of the offload strategy. + extra_cost (float): extra data transfer cost. + + Returns: + tuple: profit information, the first term represents memory savings per unit of time. + """ + + if extra_cost == 0: + # means data transfer overhead can be completely overlapped + return (float('inf'), total_mem_saving, peak_mem_saving) + return (total_mem_saving / extra_cost, total_mem_saving, peak_mem_saving) + + def _compare_profit(self, profit_a: tuple, profit_b: tuple) -> bool: + """ + Compare the profits of the two offload strategies using the dictionary order algorithm. + + Args: + profit_a (tuple): the profit of a offload strategy. + profit_b (tuple): the profit of another offload strategy. + + Returns: + bool: whether profit_a is greater than profit_b. + """ + + for val1, val2 in zip(profit_a, profit_b): + if val1 != val2: + return val1 > val2 + return False + + def _update_state(self, best_ts: TrainingSimulator): + """ + Update the solver state. + """ + + self.best_ts = best_ts + self._update_node_mem_info(best_ts.fwd_node_mem, best_ts.bwd_node_mem) + + def _update_node_mem_info(self, + fwd_mem_info: Dict[Node, float], + bwd_mem_info: Dict[Node, float]): + """ + Update the runtime memory information of the node. + + Args: + fwd_mem_info (Dict[Node, float]): the runtime memory of each node in forward pass. + bwd_mem_info (Dict[Node, float]): the runtime memory of each node in backward pass. + """ + + for node, mem in fwd_mem_info.items(): + assert hasattr(node, 'node_info') and isinstance( + node.node_info, NodeInfo) + node.node_info.runtime_fwd_mem = mem + for node, mem in bwd_mem_info.items(): + assert hasattr(node, 'node_info') and isinstance( + node.node_info, NodeInfo) + node.node_info.runtime_bwd_mem = mem + + def _extract_computing_power(self): + """ + return the FP16 computing performance of the current NVIDIA GPU. + + Raises: + TypeError: Unknown NVIDIA GPU device. 
+ """ + + nvmlInit() + handle = nvmlDeviceGetHandleByIndex(0) + device_name = nvmlDeviceGetName(handle) + units = 1e12 + + if device_name.__contains__("RTX 3080"): + return NvDevicePower.RTX3080_FP16 * units + elif device_name.__contains__("RTX 3090"): + return NvDevicePower.RTX3090_FP16 * units + elif device_name.__contains__('V100'): + return NvDevicePower.V100_FP16 * units + elif device_name.__contains__("A100"): + return NvDevicePower.A100_FP16 * units + else: + raise TypeError(f'Unknown NVIDIA GPU device name {device_name}') + + def _profile_bandwidth(self): + """ + Profile the bidirectional communication bandwidth between CPU and GPU + using data volumes ranging from 1KB to 1GB. + """ + + print('profiling bandwidth ......') + link_to_bandwidth = {} + links = ['h2d', 'd2h'] + + for link in links: + t_size = 1024 + size_to_bandwidth = {} + + # from 1KB to 1GB + for i in range(21): + if link == 'h2d': + src_tensor = torch.ones( + int(t_size), dtype=torch.int8, pin_memory=True) + dst_tensor = torch.ones( + (int(t_size)), dtype=torch.int8, device='cuda') + elif link == 'd2h': + src_tensor = torch.ones( + int(t_size), dtype=torch.int8, device='cuda') + dst_tensor = torch.ones( + (int(t_size)), dtype=torch.int8, pin_memory=True) + + def func(): + dst_tensor.copy_(src_tensor) + + size_to_bandwidth[t_size] = t_size / benchmark_func(func, number=5, repeat=3) + print(f'size: {t_size / 1024 ** 2:.3f} MB, ' + f'{src_tensor.device.type}-to-{dst_tensor.device.type} ' + f'bandwidth: {size_to_bandwidth[t_size] / 1024 ** 3:.3f} GB/s') + + t_size *= 2 + + link_to_bandwidth[link] = size_to_bandwidth + return link_to_bandwidth + + +class SynGreedySolver(Solver): + + def __init__(self, + region_list: List[Region], + memory_budget: float = -1.0) -> None: + super().__init__(region_list, memory_budget) + + self.best_ts: SynTrainingSimulator = None + self._init_state() + + def _init_state(self): + """ + Initialize the solver state when without offloading. + """ + + ts = SynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth) + ts.execute() + self._update_state(ts) + + def _call_solver(self): + """ + Call the solver to search an efficient parameter offloading strategy for the linearized graph. + The solver adopts greedy algorithm. + + Raises: + NotImplementedError: Unable to find a solution for the given memory budget. + """ + + print("search offloading strategy ......") + while self.best_ts.peak_mem > self.memory_budget: + offload_region = None + best_ts = None + max_profit = (0,) + + # search which region should be offloaded, + # the last region does not need to be offloaded. + for region in self.region_list[:-1]: + if region.param_size and not region.need_offload: + temp_ts, profit = self._try_to_offload(region) + if self._compare_profit(profit, max_profit): + offload_region = region + max_profit = profit + best_ts = temp_ts + + if offload_region is not None and best_ts is not None: + offload_region.need_offload = True + offload_region.is_syn = True + self._update_state(best_ts) + else: + raise NotImplementedError( + f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, " + f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!") + + def _call_solver_l2l(self): + """ + The layer-wise offload strategy. 
+ """ + + for region in self.region_list[:-1]: + region.need_offload = True + region.is_syn = True + + def _try_to_offload(self, offload_region: Region): + + # record previous information + orig_need_offload = offload_region.need_offload + assert not orig_need_offload + offload_region.need_offload = True + + ts, profit = self._eval_one_choice(offload_region) + + # restore previous information + offload_region.need_offload = orig_need_offload + return ts, profit + + def _eval_one_choice(self, offload_region: Region): + """ + Evaluate the profit of a strategy choice. + + Args: + offload_region (Region): the offload region of current choice. + + Returns: + SynTrainingSimulator: the training simulator corresponding to the current strategy. + tuple: contains memory saving and cost information of the current strategy. + """ + + ts = SynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth) + ts.execute() + + extra_comm_cost = 2.0 * \ + ts._get_communication_overhead('h2d', offload_region.param_size) + # the shared region needs to be moved twice + if offload_region.r_id < offload_region.shared_rid: + extra_comm_cost *= 2.0 + profit = self._compute_offload_profit( + ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost) + + return ts, profit + + +class AsynGreedySolver(Solver): + + def __init__(self, + region_list: List[Region], + memory_budget: float = -1.0, + search_window_size: int = 3): + super().__init__(region_list, memory_budget) + + self.search_window_size = search_window_size + # Records the prefetch execution location of the offloaded region + self.region_to_region_map = {} + self.best_ts: AsynTrainingSimulator = None + + self._init_state() + + def _init_state(self): + """ + Initialize the solver state when without offloading. + """ + + ts = AsynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth) + ts.execute() + self._update_state(ts) + print("init peak memory", self.best_ts.peak_mem / 1024 ** 2, "MB") + + def _call_solver(self): + """ + Call the solver to search an efficient parameter offloading strategy for the linearized graph. + The solver adopts greedy algorithm. + + Raises: + NotImplementedError: Unable to find a solution for the given memory budget. 
+ """ + + print("search for offloading strategy ......") + # Records the prefetch execution location of the offloaded region + region_to_region_map = {} + while self.best_ts.peak_mem > self.memory_budget: + region_to_offload = None + max_offload_profit = (0,) + best_offl_ts = None + + # search which region should be offloaded, + # the last region does not need to be offloaded + for region in self.region_list[:-1]: + if region.param_size and not region.need_offload: + max_prefetch_profit = (0,) + best_pref_ts = None + + # search when to prefetch the region offloaded + for host_region in self.region_list[region.r_id + 1:region.r_id + 1 + self.search_window_size]: + if host_region.bwd_prefetch_region is not None: + continue + + temp_ts, profit = self._try_to_offload( + host_region, region) + + if self._compare_profit(profit, max_prefetch_profit): + region_to_region_map[region.r_id] = host_region + max_prefetch_profit = profit + best_pref_ts = temp_ts + if profit[0] == float('inf'): + break + + if self._compare_profit(max_prefetch_profit, max_offload_profit): + region_to_offload = region + max_offload_profit = max_prefetch_profit + best_offl_ts = best_pref_ts + + if (region_to_offload is not None) and (best_offl_ts is not None): + region_to_offload.need_offload = True + if region_to_region_map[region_to_offload.r_id] == region_to_offload: + region_to_offload.is_syn = True + else: + region_to_region_map[region_to_offload.r_id].bwd_prefetch_region = region_to_offload + self.region_to_region_map[region_to_offload.r_id] = region_to_region_map[region_to_offload.r_id] + + self._update_state(best_offl_ts) + + elif self.region_to_region_map.__len__() > 0: + self._repair_strategy() + else: + raise NotImplementedError( + f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, " + f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!") + + region_to_region_map.clear() + + def _try_to_offload(self, host_region: Region, offload_region: Region): + """ + Attempts to offload the region and prefetch it in backward pass. + """ + + # record previous information + orig_prefetch = host_region.bwd_prefetch_region + orig_is_syn = offload_region.is_syn + orig_need_offload = offload_region.need_offload + + if host_region == offload_region: + offload_region.is_syn = True + else: + host_region.bwd_prefetch_region = offload_region + offload_region.need_offload = True + + ts, profit = self._eval_one_choice() + + # restore previous information + host_region.bwd_prefetch_region = orig_prefetch + offload_region.is_syn = orig_is_syn + offload_region.need_offload = orig_need_offload + + return ts, profit + + def _try_convert_to_syn_upload(self, host_region: Region, offload_region: Region): + """ + Attempts to convert asynchronous prefetch into synchronous upload operations. + """ + + # record previous information + orig_prefetch = host_region.bwd_prefetch_region + orig_is_syn = offload_region.is_syn + assert orig_prefetch is not None and not orig_is_syn + + host_region.bwd_prefetch_region = None + offload_region.is_syn = True + + ts, profit = self._eval_one_choice() + + # restore previous information + host_region.bwd_prefetch_region = orig_prefetch + offload_region.is_syn = orig_is_syn + + return ts, profit + + def _repair_strategy(self): + """ + Repair offload strategy. + It attempts to convert asynchronous prefetch into synchronous upload operations and selects the best one. 
+ The repair process does not end until peak memory is reduced or there is no asynchronous prefetch operation. + """ + print("repair strategy ......") + + peak_mem_saving = 0 + while len(self.region_to_region_map) and peak_mem_saving <= 0: + + max_profit = (0,) + best_ts = None + undo_host_region = None + undo_offload_region = None + + for offload_region_id, host_region in self.region_to_region_map.items(): + offload_region = self.region_list[offload_region_id] + assert host_region.bwd_prefetch_region == offload_region + assert offload_region.need_offload + assert not offload_region.is_syn + + ts, profit = self._try_convert_to_syn_upload(host_region, + offload_region) + + if self._compare_profit(profit, max_profit): + undo_host_region = host_region + undo_offload_region = offload_region + max_profit = profit + best_ts = ts + + if best_ts is None: + raise NotImplementedError('repair error!') + + assert not undo_offload_region.is_syn + undo_offload_region.is_syn = True + undo_host_region.bwd_prefetch_region = None + + peak_mem_saving = self.best_ts.peak_mem - best_ts.peak_mem + + self._update_state(best_ts) + self.region_to_region_map.pop(undo_offload_region.r_id) + + return best_ts + + def _eval_one_choice(self): + """ + Evaluate the profit of a strategy choice. + + Returns: + AsynTrainingSimulator: the training simulator corresponding to the current strategy. + tuple: contains memory saving and cost information of the current strategy. + """ + + ts = AsynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth) + ts.execute() + + extra_comm_cost = max(ts.iter_end_time - self.best_ts.iter_end_time, 0) + profit = self._compute_offload_profit( + ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost) + + return ts, profit + + +class SolverFactory: + solvers: Dict[str, Type[Solver]] = { + 'syn': SynGreedySolver, + 'asyn': AsynGreedySolver + } + + @staticmethod + def create(solver_name: str) -> Type[Solver]: + if solver_name not in SolverFactory.solvers: + raise TypeError(f"Unknown parameter offload policy {solver_name}") + return SolverFactory.solvers[solver_name] + + @staticmethod + def get_solver_names(): + return tuple(SolverFactory.solvers.keys()) diff --git a/colossalai/auto_parallel/offload/training_simulator.py b/colossalai/auto_parallel/offload/training_simulator.py new file mode 100644 index 000000000..f277c183a --- /dev/null +++ b/colossalai/auto_parallel/offload/training_simulator.py @@ -0,0 +1,458 @@ +import bisect +from typing import List, Dict +from collections import OrderedDict +from abc import ABC, abstractmethod + +from torch.fx.node import Node + +from .region import Region +from .util import * + + +@dataclass +class ExecutionPeriod: + start_time: float = 0 + end_time: float = 0 + + +class TrainingSimulator(ABC): + """ + The Training Simulator is used to simulate the training process. + It records computation, communication, and runtime memory during forward and backward passes. + + Args: + region_list (List[Region]): represents the linearized DNN computing graph. + comp_power (float): the NVIDIA GPU FP16 compuing power. + link_to_bw (Dict[str, Dict[float, float]]): communication links and the corresponding bandwidth. 
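+
+    Example shape of ``link_to_bw`` (illustrative numbers; produced by
+    ``Solver._profile_bandwidth``, keys are transfer sizes in bytes and values are
+    measured bandwidths in bytes/s)::
+
+        {'h2d': {1024: 1.0e8, 2048: 2.0e8, ..., 2 ** 30: 1.2e10},
+         'd2h': {1024: 9.0e7, 2048: 1.8e8, ..., 2 ** 30: 1.1e10}}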
+ """ + + def __init__(self, + region_list: List[Region], + comp_power: float, + link_to_bw: Dict[str, Dict[float, float]]) -> None: + self.region_list = region_list + self.region_num = len(region_list) + + self.runtime_mem: int = 0 + self.peak_mem: int = 0 + self.total_mem_saving: int = 0 + + self.fwd_node_mem: Dict[Node, float] = {} + self.bwd_node_mem: Dict[Node, float] = {} + + # Node dependencies in backward pass + self.bwd_node_deps: Dict[Node, int] = {} + + self.comp_power: float = comp_power + self.link_to_bandwidth: Dict[str, Dict[float, float]] = link_to_bw + + @abstractmethod + def execute(self): + raise NotImplementedError + + @abstractmethod + def _eval_fwd_mem_per_region(self, region: Region): + raise NotImplementedError + + @abstractmethod + def _eval_bwd_mem_per_region(self, region: Region): + raise NotImplementedError + + def _get_bandwidth(self, link: str, comm_volumn: float) -> float: + """ + Get the data transfer bandwidth. + + Args: + link (str): the data transfer link. + comm_volumn (float): the amount of data transferred. + + Returns: + float: the data transfer bandwidth. + """ + + assert len(self.link_to_bandwidth) + if link not in self.link_to_bandwidth: + raise TypeError(f"Unknown data transfer link {link}") + + # size_list = sorted(list(map(float, self.link_to_bandwidth[link].keys()))) + size_list = sorted(self.link_to_bandwidth[link].keys()) + d_idx = bisect.bisect_left(size_list, comm_volumn) + return self.link_to_bandwidth[link][size_list[d_idx]] + + def _get_communication_overhead(self, link: str, comm_volumn: float) -> float: + return comm_volumn / self._get_bandwidth(link, comm_volumn) + + def _get_computing_overhead(self, flop: float) -> float: + return flop / self.comp_power + + +class SynTrainingSimulator(TrainingSimulator): + + def __init__(self, + region_list: List[Region], + comp_power: float, + link_to_bw: Dict[str, Dict[float, float]]) -> None: + super().__init__(region_list, comp_power, link_to_bw) + + def execute(self): + """ + Simulate synchronous training process. + """ + + for reg in self.region_list: + self._eval_fwd_mem_per_region(reg) + + for reg in self.region_list.__reversed__(): + self._eval_bwd_mem_per_region(reg) + + def _eval_fwd_mem_per_region(self, region: Region): + """ + Evaluate the runtime and peak memory when the forward execution reaches the current region. + """ + + # upload parameters of the current region + if requires_upload_p_in_fwd(self.region_list[region.shared_rid]): + self.runtime_mem += region.param_size + + for node in region.nodes: + self.runtime_mem += calculate_fwd_tmp(node) + \ + calculate_fwd_out(node) + self.fwd_node_mem[node] = self.runtime_mem + self.peak_mem = max(self.runtime_mem, self.peak_mem) + self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem + + if region.need_offload: + self.runtime_mem -= region.param_size + + def _eval_bwd_mem_per_region(self, region: Region): + """ + Evaluate the runtime and peak memory when the backward execution reaches the current region. 
+ """ + + # upload parameters of the current region + if region.need_offload: + self.runtime_mem += region.param_size + + # add the gradient of the parameter + if region.r_id < region.shared_rid: + # gradient accumulation is required for shared parameters + self.runtime_mem += 2.0 * region.param_size + else: + self.runtime_mem += region.param_size + + for node in region.nodes.__reversed__(): + + self.runtime_mem -= calculate_fwd_out(node) + self.runtime_mem += node.meta['bwd_mem_tmp'] + \ + node.meta['bwd_mem_out'] + self.peak_mem = max(self.runtime_mem, self.peak_mem) + + # The memory savings of a node may be negative due to parameter prefetch. + self.total_mem_saving += node.node_info.runtime_bwd_mem - self.runtime_mem + self.bwd_node_mem[node] = self.runtime_mem + + self.runtime_mem -= (node.meta['bwd_mem_tmp'] + + calculate_fwd_tmp(node)) + + # free bwd_mem_out + self.bwd_node_deps[node] = len(node.all_input_nodes) + for user_node in node.users: + if user_node in self.bwd_node_deps: + self.bwd_node_deps[user_node] -= 1 + if self.bwd_node_deps[user_node] <= 0: + self.runtime_mem -= user_node.meta['bwd_mem_out'] + + if self.runtime_mem < 0: + raise ValueError(f"region id: {region.r_id}, node name: {node.name}, " + f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---" + f"runtime memory computed less than 0, which is miscalculated!") + + # release parameter and offload gradient in region + if region.r_id == region.shared_rid: + self.runtime_mem -= 2.0 * region.param_size + elif region.r_id < region.shared_rid: + self.runtime_mem -= 3.0 * region.param_size + elif self.region_list[region.shared_rid].need_offload: + self.runtime_mem -= region.param_size + + +class AsynTrainingSimulator(TrainingSimulator): + + def __init__(self, + region_list: List[Region], + comp_power: float, + link_to_bw: Dict[str, Dict[float, float]]) -> None: + super().__init__(region_list, comp_power, link_to_bw) + + self.iter_end_time: int = 0 + # the last computation execution period + self.last_comp: ExecutionPeriod = ExecutionPeriod( + start_time=0, end_time=0) + # the last parameter prefetch execution period + self.last_h2d: ExecutionPeriod = ExecutionPeriod( + start_time=0, end_time=0) + # the last gradient offload execution period + self.last_d2h: ExecutionPeriod = ExecutionPeriod( + start_time=0, end_time=0) + # the forward computation execution period of the region + self.fwd_reg_to_comp: OrderedDict[int, ExecutionPeriod] = OrderedDict() + # the forward parameter prefetch execution period of the region + self.fwd_reg_to_pref: OrderedDict[int, ExecutionPeriod] = OrderedDict() + # the backward computation execution period of the region + self.bwd_reg_to_comp: OrderedDict[int, ExecutionPeriod] = OrderedDict() + # the backward parameter prefetch execution period of the region + self.bwd_reg_to_pref: OrderedDict[int, ExecutionPeriod] = OrderedDict() + # the gradient offload execution period of the region + # which is divided into those that are waiting and those that have been released + self.bwd_reg_to_offl_waiting: OrderedDict[int, + ExecutionPeriod] = OrderedDict() + self.bwd_reg_to_offl_freed: OrderedDict[int, + ExecutionPeriod] = OrderedDict() + # the region buffer, which records regions that are offloaded but not released + self.reg_buffer_to_free: List[int] = [] + + # node dependencies in backward pass + self.bwd_node_deps: Dict[Node, int] = {} + + # the region execution flow, + # where fwd_reg_flow[i,j] denotes whether the parameters of j-th region are in the GPU + # when the execution reaches the i-th 
region. + self.fwd_reg_flow = torch.zeros( + (self.region_num, self.region_num)).bool() + self.bwd_reg_flow = torch.zeros( + (self.region_num, self.region_num)).bool() + + def execute(self): + """ + Simulate asynchronous training process. + In forward pass, parameter prefetching is advanced by one region. + In backward pass, parameter prefetching is executed at the specified location, + and gradient offloading is urgent. + """ + + for reg in self.region_list: + if reg.param_size and reg.r_id < self.region_num - 1: + for nr in self.region_list[reg.r_id + 1:]: + if nr.param_size and requires_upload_p_in_fwd(self.region_list[nr.shared_rid]): + reg.fwd_prefetch_region = nr + break + self._eval_fwd_cost_per_region(reg) + self._eval_fwd_mem_per_region(reg) + + for reg in self.region_list.__reversed__(): + self._eval_bwd_cost_per_region(reg) + self._eval_bwd_mem_per_region(reg) + + # release remaining grads + for reg_id, offl_exec in self.bwd_reg_to_offl_waiting.items(): + self.bwd_reg_to_offl_freed[reg_id] = offl_exec + self.runtime_mem -= self.region_list[reg_id].param_size + self.bwd_reg_to_offl_waiting.clear() + + self.iter_end_time = max( + self.last_comp.end_time, self.last_d2h.end_time) + + def _insert_h2d_exec(self, region: Region, is_fwd: bool = True): + """ + Insert parameter prefetch execution period of the current region to the end of the h2d stream + """ + + pref_start_time = max(self.last_h2d.end_time, self.last_comp.end_time) + pref_end_time = pref_start_time + \ + 2.0 * self._get_communication_overhead('h2d', region.param_size) + pref_ep = ExecutionPeriod( + start_time=pref_start_time, end_time=pref_end_time) + if is_fwd: + self.fwd_reg_to_pref[region.r_id] = pref_ep + else: + self.bwd_reg_to_pref[region.r_id] = pref_ep + self.last_h2d = pref_ep + + def _insert_comp_exec(self, region: Region, is_fwd: bool = True): + """ + Insert computation execution period of the current region to the end of the computing stream + """ + + if is_fwd: + reg_to_comp = self.fwd_reg_to_comp + reg_to_pref = self.fwd_reg_to_pref + flop_key = 'fwd_flop' + else: + reg_to_comp = self.bwd_reg_to_comp + reg_to_pref = self.bwd_reg_to_pref + flop_key = 'bwd_flop' + comp_start_time = max(self.last_comp.end_time, reg_to_pref.get( + region.r_id, ExecutionPeriod(0, 0)).end_time) + comp_end_time = comp_start_time + \ + sum([self._get_computing_overhead(node.meta.get(flop_key, 0)) + for node in region.nodes]) + comp_ep = ExecutionPeriod( + start_time=comp_start_time, end_time=comp_end_time) + reg_to_comp[region.r_id] = comp_ep + self.last_comp = comp_ep + + def _insert_d2h_exec(self, region: Region): + """ + Insert gradient offload execution period of the current region to the end of the d2h stream + """ + + offl_start_time = max(self.last_d2h.end_time, self.last_comp.end_time) + offl_end_time = offl_start_time + \ + self._get_communication_overhead('d2h', region.param_size) + offl_ep = ExecutionPeriod( + start_time=offl_start_time, end_time=offl_end_time) + self.bwd_reg_to_offl_waiting[region.r_id] = offl_ep + self.last_d2h = offl_ep + + def _eval_fwd_cost_per_region(self, region: Region): + """ + Evaluate computation and communication execution period of the region in forward pass. 
+ """ + + # upload parameters of the first region + if region.r_id == 0: + self._insert_h2d_exec(region) + + # prefetch parameters of the next region + fwd_prefetch_region = region.fwd_prefetch_region + if fwd_prefetch_region and requires_upload_p_in_fwd(self.region_list[fwd_prefetch_region.shared_rid]): + self._insert_h2d_exec(fwd_prefetch_region) + + # execute computation + self._insert_comp_exec(region) + + def _eval_fwd_mem_per_region(self, region: Region): + """ + Evaluate the runtime and peak memory when the forward execution reaches the current region. + """ + + # upload parameters of the current region + if region.r_id <= 0: + self.runtime_mem += region.param_size + self.fwd_reg_flow[region.r_id, region.r_id] = True + else: + self.fwd_reg_flow[region.r_id] = self.fwd_reg_flow[region.r_id - 1] + self.fwd_reg_flow[region.r_id, + self.reg_buffer_to_free] = False + self.reg_buffer_to_free.clear() + + # prefetch parameters of the next region + fwd_prefetch_region = region.fwd_prefetch_region + if fwd_prefetch_region and requires_upload_p_in_fwd(self.region_list[fwd_prefetch_region.shared_rid]): + self.runtime_mem += fwd_prefetch_region.param_size + self.fwd_reg_flow[region.r_id, + fwd_prefetch_region.r_id] = True + + for node in region.nodes: + self.runtime_mem += calculate_fwd_tmp(node) + \ + calculate_fwd_out(node) + self.peak_mem = max(self.runtime_mem, self.peak_mem) + + self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem + self.fwd_node_mem[node] = self.runtime_mem + + if region.need_offload: + self.runtime_mem -= region.param_size + + assert len( + self.reg_buffer_to_free) <= 1, f'{len(self.reg_buffer_to_free)}' + self.reg_buffer_to_free.append(region.r_id) + + def _eval_bwd_cost_per_region(self, region: Region): + """ + Evaluate computation and communication execution period of the region in backward pass. + """ + + # upload parameters of the current region + if region.is_syn: + assert region.need_offload + self._insert_h2d_exec(region, is_fwd=False) + + # prefetch parameters of the region choiced, which is parallel to computation + if region.bwd_prefetch_region is not None: + self._insert_h2d_exec(region.bwd_prefetch_region, is_fwd=False) + + # execute computation + self._insert_comp_exec(region, is_fwd=False) + + # offload gradient + if requires_offload_g_in_bwd(region): + self._insert_d2h_exec(region) + + assert len(self.reg_buffer_to_free) == 0 + for reg_id, offl_exec in self.bwd_reg_to_offl_waiting.items(): + if offl_exec.end_time >= self.last_comp.start_time: + break + self.reg_buffer_to_free.append(reg_id) + self.bwd_reg_to_offl_freed[reg_id] = offl_exec + + for reg_id in self.reg_buffer_to_free: + self.bwd_reg_to_offl_waiting.pop(reg_id) + + def _eval_bwd_mem_per_region(self, region: Region): + """ + Evaluate the runtime and peak memory when the backward execution reaches the current region. 
+ """ + + if region.r_id + 1 < self.region_num: + self.bwd_reg_flow[region.r_id] = self.bwd_reg_flow[region.r_id + 1] + else: + self.bwd_reg_flow[region.r_id] = self.fwd_reg_flow[-1] + self.bwd_reg_flow[region.r_id, + self.reg_buffer_to_free] = False + + # free gradients in the buffer + while len(self.reg_buffer_to_free): + reg_id = self.reg_buffer_to_free.pop(0) + self.runtime_mem -= self.region_list[reg_id].param_size + + # upload parameters of the current region + if region.is_syn: + self.runtime_mem += region.param_size + self.bwd_reg_flow[region.r_id, region.r_id] = True + + # prefetch parameters of the region choiced + bwd_prefetch_region = region.bwd_prefetch_region + if bwd_prefetch_region: + self.runtime_mem += bwd_prefetch_region.param_size + self.bwd_reg_flow[region.r_id, + bwd_prefetch_region.r_id] = True + + # add the gradient of the parameter + if region.r_id < region.shared_rid: + # gradient accumulation is required for shared parameters + self.runtime_mem += 2.0 * region.param_size + else: + self.runtime_mem += region.param_size + + for node in region.nodes.__reversed__(): + + self.runtime_mem -= calculate_fwd_out(node) + self.runtime_mem += node.meta['bwd_mem_tmp'] + \ + node.meta['bwd_mem_out'] + self.peak_mem = max(self.runtime_mem, self.peak_mem) + + # The memory savings of a node may be negative due to parameter prefetch. + self.total_mem_saving += node.node_info.runtime_bwd_mem - self.runtime_mem + + self.bwd_node_mem[node] = self.runtime_mem + + self.runtime_mem -= (node.meta['bwd_mem_tmp'] + + calculate_fwd_tmp(node)) + + # free bwd_mem_out + self.bwd_node_deps[node] = len(node.all_input_nodes) + for user_node in node.users: + if user_node in self.bwd_node_deps: + self.bwd_node_deps[user_node] -= 1 + if self.bwd_node_deps[user_node] <= 0: + self.runtime_mem -= user_node.meta['bwd_mem_out'] + + if self.runtime_mem < 0: + raise ValueError(f"region id: {region.r_id}, node name: {node.name}, " + f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---" + f"runtime memory computed less than 0, which is miscalculated!") + + # release parameters of the region + if requires_release_p_in_bwd(self.region_list[region.shared_rid]): + self.runtime_mem -= region.param_size diff --git a/colossalai/auto_parallel/offload/util.py b/colossalai/auto_parallel/offload/util.py new file mode 100644 index 000000000..a99c4eb20 --- /dev/null +++ b/colossalai/auto_parallel/offload/util.py @@ -0,0 +1,90 @@ +from dataclasses import dataclass +from typing import List +import torch +from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp + +from .region import Region + + +@dataclass +class NodeInfo: + node_id: int = 0 + runtime_fwd_mem: float = 0 + runtime_bwd_mem: float = 0 + +class NvDevicePower: + """ + NVIDIA GPU computing performance (TFLOPs). 
+ """ + + RTX3080_FP16 = 70 + RTX3080_FP32 = 34.1 + + RTX3090_FP16 = 71 + RTX3090_FP32 = 35.7 + + V100_FP16 = 31.4 + V100_FP32 = 15.7 + + A100_FP16 = 78 + A100_FP32 = 19.5 + + +class GlobalRuntimeInfo: + h2d_stream = torch.cuda.Stream() + d2h_stream = torch.cuda.Stream() + fwd_prefetch_event_map = {} + bwd_prefetch_event_map = {} + region_list = [] + + +def compute_act_peak_mem(region_list: List[Region]) -> float: + act_peak_mem = 0 + runtime_mem = 0 + # forward + for region in region_list: + for node in region.nodes: + runtime_mem = runtime_mem + \ + calculate_fwd_tmp(node) + calculate_fwd_out(node) + act_peak_mem = max(runtime_mem, act_peak_mem) + # backward + bwd_deps = {} + for region in region_list.__reversed__(): + for node in region.nodes.__reversed__(): + runtime_mem -= calculate_fwd_out(node) + runtime_mem = runtime_mem + \ + node.meta['bwd_mem_tmp'] + node.meta['bwd_mem_out'] + + act_peak_mem = max(runtime_mem, act_peak_mem) + + runtime_mem = runtime_mem - \ + node.meta['bwd_mem_tmp'] - calculate_fwd_tmp(node) + + # free bwd_mem_out + bwd_deps[node] = len(node.all_input_nodes) + for user_node in node.users: + if user_node in bwd_deps: + bwd_deps[user_node] -= 1 + if bwd_deps[user_node] <= 0: + runtime_mem -= user_node.meta['bwd_mem_out'] + + return act_peak_mem + +def compute_max_param_mem(region_list: List[Region]) -> float: + return max(region.param_size for region in region_list) + +def compute_total_param_mem(region_list: List[Region]) -> float: + return sum(region.param_size for region in region_list if region.r_id <= region.shared_rid) + +def requires_upload_p_in_fwd(shared_reg: Region): + return (shared_reg.r_id >= shared_reg.shared_rid) or ( + shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload) + +def requires_release_p_in_bwd(shared_reg: Region): + return (shared_reg.r_id >= shared_reg.shared_rid) or ( + shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload) + +def requires_offload_g_in_bwd(region: Region): + return region.param_size and (region.r_id <= region.shared_rid) + + diff --git a/examples/language/gpt/experiments/auto_offload/README.md b/examples/language/gpt/experiments/auto_offload/README.md new file mode 100644 index 000000000..a0d252119 --- /dev/null +++ b/examples/language/gpt/experiments/auto_offload/README.md @@ -0,0 +1,37 @@ +# Auto-Offload Demo with GPT2 + +## Requirements + +Before you can launch training, you need to install the following requirements. + +### Install PyTorch + +```bash +#conda +conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch +#pip +pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113 +``` + +### Install [Colossal-AI v0.2.0](https://colossalai.org/download/) From Official Website + +```bash +pip install colossalai==0.2.0+torch1.12cu11.3 -f https://release.colossalai.org +``` + +### Install transformers + +```bash +pip install transformers +``` + +## Dataset + +For simplicity, the input data is randonly generated here. + +## Training + +```bash +#Run the auto offload on GPT with default setting and a dummy dataset. 
+bash run.sh +``` diff --git a/examples/language/gpt/experiments/auto_offload/model_zoo.py b/examples/language/gpt/experiments/auto_offload/model_zoo.py new file mode 100644 index 000000000..35e44608f --- /dev/null +++ b/examples/language/gpt/experiments/auto_offload/model_zoo.py @@ -0,0 +1,65 @@ +import torch +import torch.nn as nn +from transformers import GPT2Config, GPT2LMHeadModel + +class GPTLMModel(nn.Module): + + def __init__(self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50257): + super().__init__() + self.model = GPT2LMHeadModel( + GPT2Config(n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size)) + + def forward(self, input_ids, attention_mask): + # Only return lm_logits + return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0] + + +class GPTLMLoss(nn.Module): + + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + +def get_gpt2_components(model_type: str, batch_size: int): + vocab_size = 1024 + seq_len = 8 + + def gpt2_model_builder(): + if model_type == "gpt2_medium": + return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16) + elif model_type == "gpt2_xl": + return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32) + elif model_type == "gpt2_10b": + return GPTLMModel(hidden_size=4096, num_layers=50, num_attention_heads=16) + elif model_type == "gpt2_14b": + return GPTLMModel(hidden_size=4096, num_layers=70, num_attention_heads=16) + elif model_type == "gpt2_20b": + return GPTLMModel(hidden_size=8192, num_layers=25, num_attention_heads=16) + elif model_type == "gpt2_24b": + return GPTLMModel(hidden_size=8192, num_layers=30, num_attention_heads=16) + else: + raise TypeError(f"model_builder {model_type}") + + def gpt2_data_gen(device="cuda"): + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device) + attention_mask = torch.ones_like(input_ids, device=device) + kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) + return kwargs + + return gpt2_model_builder, gpt2_data_gen \ No newline at end of file diff --git a/examples/language/gpt/experiments/auto_offload/requirements.txt b/examples/language/gpt/experiments/auto_offload/requirements.txt new file mode 100644 index 000000000..3ebde8d46 --- /dev/null +++ b/examples/language/gpt/experiments/auto_offload/requirements.txt @@ -0,0 +1,2 @@ +colossalai >= 0.1.12 +torch >= 1.8.1 \ No newline at end of file diff --git a/examples/language/gpt/experiments/auto_offload/run.sh b/examples/language/gpt/experiments/auto_offload/run.sh new file mode 100644 index 000000000..6a272ec44 --- /dev/null +++ b/examples/language/gpt/experiments/auto_offload/run.sh @@ -0,0 +1,8 @@ +export BATCH_SIZE=${BATCH_SIZE:-64} +export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"} +export MEMORY_BUDGET=${MEMORY_BUDGET:-16} +export SOLVER_TYPE=${SOLVER_TYPE:-"asyn"} + +mkdir -p offload_logs + +python train_gpt_offload.py --model_type=${MODEL_TYPE} --memory_budget=${MEMORY_BUDGET} --solver_type=${SOLVER_TYPE} --batch_size=${BATCH_SIZE} 2>&1 | tee ./offload_logs/${MODEL_TYPE}_bs_${BATCH_SIZE}_st_${SOLVER_TYPE}.log diff --git 
a/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py new file mode 100644 index 000000000..729d1ce44 --- /dev/null +++ b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py @@ -0,0 +1,94 @@ +import time +import pytest +import argparse +from functools import partial + +import torch +from torch.utils._pytree import tree_map +import torch.multiprocessing as mp + +import colossalai +from colossalai.nn.optimizer import HybridAdam +from colossalai.fx.profiler import parameter_size +from colossalai.utils import free_port, get_current_device +from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer +from colossalai.auto_parallel.offload.mem_optimize import memory_optimize +from colossalai.auto_parallel.offload.solver import NOT_NVML +from model_zoo import get_gpt2_components, GPTLMLoss + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--model_type', type=str, default="gpt2_medium") + parser.add_argument('--batch_size', type=int, default=64) + parser.add_argument('--solver_type', type=str, default='asyn') + parser.add_argument('--memory_budget', type=float, default=16) + return parser.parse_args() + +@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed') +def train_gpt(args): + memory_budget = args.memory_budget * 1024 * 1024 * 1024 + solver_type = args.solver_type + model_type = args.model_type + batch_size = args.batch_size + + # build model + model_builder, data_gen = get_gpt2_components(model_type=model_type, batch_size=batch_size) + label = torch.randint(low=0, high=128, size=(64, 8,), device=get_current_device()) + criterion = GPTLMLoss() + + start_time = time.time() + model = model_builder() + model.train() + param_size = parameter_size(model) / 1024 ** 2 / 2 + init_time = time.time() - start_time + print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s") + + data_args = data_gen(device="cpu") + wrap_fn = lambda x: x.to(dtype=torch.half) if isinstance(x, torch.Tensor) and torch.is_floating_point(x) else x + data_args = tree_map(wrap_fn, data_args) + start_time = time.time() + model = memory_optimize(model, data_args, memory_budget, solver_type) + solver_time = time.time() - start_time + print(f"solver_time={solver_time:.3f} s") + + hybrid_optimizer = HybridAdam(model.model.parameters(), lr=1e-3) + optim = AMPOptimizer(hybrid_optimizer, model) + + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + + time_list = [] + data_args = data_gen(device="cuda") + data_args = tree_map(wrap_fn, data_args) + for step in range(10): + optim.zero_grad() + torch.cuda.synchronize() + start_time = time.time() + loss = criterion(model(**data_args), label) + optim.backward(loss) + torch.cuda.synchronize() + time_list.append(time.time() - start_time) + optim.step() + + torch.cuda.synchronize() + + exec_time = sum(sorted(time_list)[:5]) / 5 + runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2 + runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2 + print(f'solver_type: {solver_type} | model_type: {model_type}') + print( + f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' + f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|' + ) + print(time_list) + +def run(rank, world_size, port, args): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', 
port=port, backend='nccl') + train_gpt(args) + +if __name__ == '__main__': + args = parse_args() + run_func = partial(run, world_size=1, port=free_port(), args=args) + mp.spawn(run_func, nprocs=1) diff --git a/tests/test_auto_parallel/test_offload/model_utils.py b/tests/test_auto_parallel/test_offload/model_utils.py new file mode 100644 index 000000000..c22b17ae4 --- /dev/null +++ b/tests/test_auto_parallel/test_offload/model_utils.py @@ -0,0 +1,86 @@ +import torch +import torch.nn as nn +from transformers import GPT2Config, GPT2LMHeadModel +from transformers import BertConfig, BertLMHeadModel +from tests.components_to_test.registry import non_distributed_component_funcs + +class GPTLMModel(nn.Module): + + def __init__(self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50257): + super().__init__() + self.model = GPT2LMHeadModel( + GPT2Config(n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size)) + + def forward(self, input_ids, attention_mask): + # Only return lm_logits + return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0] + + +class LMLoss(nn.Module): + + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + +class BertLMModel(nn.Module): + def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=32, vocab_size=30522): + super().__init__() + self.model = BertLMHeadModel(BertConfig(n_embd=hidden_size, num_hidden_layers=num_layers, hidden_size=hidden_size, + num_attention_heads=num_attention_heads, max_position_embeddings=hidden_size, + vocab_size=vocab_size)) + + def forward(self, input_ids, attention_mask): + # Only return lm_logits + return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0] + +@non_distributed_component_funcs.register(name='bert_') +def get_bert_components(): + vocab_size = 1024 + seq_len = 64 + batchSize = 64 + + def bert_model_builder(): + model = BertLMModel(hidden_size=8192, num_layers=4, num_attention_heads=32, vocab_size=vocab_size) + return model + + def bert_data_gen(device="meta"): + input_ids = torch.randint(0, vocab_size, (batchSize, seq_len), device=device) + attention_mask = torch.ones_like(input_ids, device=device) + kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) + return kwargs + + return bert_model_builder, bert_data_gen + +@non_distributed_component_funcs.register(name='gpt2_') +def get_gpt2_components(): + vocab_size = 1024 + seq_len = 8 + batchSize = 64 + + def gpt2_model_builder(): + model = GPTLMModel(hidden_size=8192, num_layers=2, num_attention_heads=32, vocab_size=vocab_size) + return model + + def gpt2_data_gen(device="meta"): + input_ids = torch.randint(0, vocab_size, (batchSize, seq_len), device=device) + attention_mask = torch.ones_like(input_ids, device=device) + kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) + return kwargs + + return gpt2_model_builder, gpt2_data_gen \ No newline at end of file diff --git a/tests/test_auto_parallel/test_offload/test_perf.py b/tests/test_auto_parallel/test_offload/test_perf.py new file mode 100644 index 000000000..d569570f4 --- /dev/null +++ 
b/tests/test_auto_parallel/test_offload/test_perf.py @@ -0,0 +1,150 @@ +import time +import pytest +from functools import partial + +import torch +from torch.utils._pytree import tree_map +import torch.multiprocessing as mp + +import colossalai +from colossalai.nn.optimizer import HybridAdam +from colossalai.fx.profiler import parameter_size +from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.utils import free_port, get_current_device +from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper +from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer +from colossalai.auto_parallel.offload.mem_optimize import memory_optimize +from colossalai.auto_parallel.offload.solver import NOT_NVML +from colossalai.testing import parameterize + +from tests.test_tensor.common_utils import set_seed +from tests.test_auto_parallel.test_offload.model_utils import * + + +@parameterize('model_name', ['gpt2_']) +@parameterize('memory_budget', [5000]) +@parameterize('solver_name', ['asyn']) +def exam_fwd_bwd( + model_name: str, + memory_budget: float, + solver_name: str +): + + # build model + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, data_gen = get_components_func() + label = torch.randint(low=0, high=128, size=(64, 8,), device=get_current_device()) + criterion = LMLoss() + + set_seed(42) + start_time = time.time() + model = model_builder() + model.train() + param_size = parameter_size(model) / 1024 ** 2 / 2 + init_time = time.time() - start_time + print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s") + + data_args = data_gen(device="cpu") + wrap_fn = lambda x: x.to(dtype=torch.half) if isinstance(x, torch.Tensor) and torch.is_floating_point(x) else x + data_args = tree_map(wrap_fn, data_args) + start_time = time.time() + model = memory_optimize(model, data_args, memory_budget * 1024 * 1024, solver_name) + solver_time = time.time() - start_time + print(f"solver_time={solver_time:.3f} s") + + hybrid_optimizer = HybridAdam(model.model.parameters(), lr=1e-3) + optim = AMPOptimizer(hybrid_optimizer, model) + + with ColoInitContext(device=torch.device('cpu')): + gemini_model = model_builder() + gemini_model.train() + + hybrid_optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3) + gemini_config = dict(strict_ddp_mode=False, + device=torch.device('cpu'), + placement_policy='cpu', + pin_memory=True, + hidden_dim=8192, + search_range_mb=128) + gemini_model = zero_model_wrapper(gemini_model, 3, gemini_config) + optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True) + gemini_optim = zero_optim_wrapper(gemini_model, hybrid_optimizer, optim_config=optim_config) + + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + + # test gemini + time_list = [] + set_seed(42) + data_args = data_gen(device="cuda") + for step in range(10): + gemini_optim.zero_grad() + torch.cuda.synchronize() + start_time = time.time() + gemini_out = gemini_model(**data_args) + gemini_loss = criterion(gemini_out, label) + gemini_optim.backward(gemini_loss) + torch.cuda.synchronize() + time_list.append(time.time() - start_time) + gemini_optim.step() + + torch.cuda.synchronize() + + exec_time = sum(sorted(time_list)[:5]) / 5 + runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2 + runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2 + print(f'gemini | model_name: {model_name}') + print( + f'| 
exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' + f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|' + ) + print(time_list) + + del data_args + del gemini_model + del gemini_optim + del gemini_out + del gemini_loss + + # test asyn offload + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + + time_list = [] + set_seed(42) + data_args = data_gen(device="cuda") + data_args = tree_map(wrap_fn, data_args) + for step in range(10): + optim.zero_grad() + torch.cuda.synchronize() + start_time = time.time() + loss = criterion(model(**data_args), label) + optim.backward(loss) + torch.cuda.synchronize() + time_list.append(time.time() - start_time) + optim.step() + + torch.cuda.synchronize() + + exec_time = sum(sorted(time_list)[:5]) / 5 + runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2 + runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2 + print(f'solver_name: {solver_name} | model_name: {model_name}') + print( + f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' + f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|' + ) + print(time_list) + +@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed') +def test_perf(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_fwd_bwd() + + +if __name__ == '__main__': + run_func = partial(test_perf, world_size=1, port=free_port()) + mp.spawn(run_func, nprocs=1) diff --git a/tests/test_auto_parallel/test_offload/test_solver.py b/tests/test_auto_parallel/test_offload/test_solver.py new file mode 100644 index 000000000..2efbb750f --- /dev/null +++ b/tests/test_auto_parallel/test_offload/test_solver.py @@ -0,0 +1,62 @@ +import pytest +import torch.fx +from torch.fx import GraphModule +from torch.utils._pytree import tree_map + +from colossalai.fx import ColoTracer, is_compatible_with_meta +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.auto_parallel.offload.region_manager import RegionManager +from colossalai.auto_parallel.offload.solver import SolverFactory, NOT_NVML +from colossalai.testing import parameterize +from tests.test_auto_parallel.test_offload.model_utils import * + +@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed') +@parameterize('model_name', ['gpt2_', 'bert_']) +@parameterize('memory_budget', [4000]) +@parameterize('solver_name', ['syn', 'asyn']) +def solver_test(model_name: str, + memory_budget: float, + solver_name: str): + + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, data_gen = get_components_func() + data_args = data_gen(device="cpu") + wrap_fn = lambda x: x.to(dtype=torch.half) if isinstance(x, torch.Tensor) and torch.is_floating_point(x) else x + data_args = tree_map(wrap_fn, data_args) + model = model_builder() + model.train() + model = model.cpu().half() + + tracer = ColoTracer() + assert is_compatible_with_meta() + wrap_fn = lambda x: x.to("meta") if isinstance(x, torch.Tensor) else x + meta_args = tree_map(wrap_fn, data_args) + graph = tracer.trace(model, meta_args=meta_args) + gm = GraphModule(model, graph, model.__class__.__name__) + + interp = MetaInfoProp(gm) + interp.propagate(*meta_args.values()) + + region_manager = RegionManager(graph, solver_name=solver_name) + region_manager._pre_process() + region_list 
= region_manager.region_list + + solver_cls = SolverFactory.create(solver_name) + memory_budget = memory_budget * 1024 * 1024 + solver = solver_cls(region_list, memory_budget) + solver._call_solver() + + assert solver.best_ts.peak_mem < memory_budget + + print("****************** execution plan *******************") + for region in region_list: + need_offload = region.need_offload + to_prefetch = region.fwd_prefetch_region.r_id if region.fwd_prefetch_region is not None else None + print(f'| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}') + for region in region_list.__reversed__(): + need_offload = region.need_offload + to_prefetch = region.bwd_prefetch_region.r_id if region.bwd_prefetch_region is not None else None + print(f'| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}') + +if __name__ == '__main__': + solver_test() \ No newline at end of file
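
A condensed sketch of the end-to-end usage, pieced together from `train_gpt_offload.py` and `run.sh` above (16 GB budget, `asyn` solver); it assumes a CUDA device and that `colossalai.launch(...)` has already been called, as in `run()`:

```python
import torch
from torch.utils._pytree import tree_map

from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer
from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
from model_zoo import get_gpt2_components, GPTLMLoss

# Build the model and a dummy-data generator from the example model zoo.
model_builder, data_gen = get_gpt2_components(model_type="gpt2_medium", batch_size=64)
model = model_builder().train()
criterion = GPTLMLoss()

# Cast floating-point inputs to fp16 before the solver traces the model.
wrap_fn = lambda x: x.to(dtype=torch.half) if isinstance(x, torch.Tensor) and torch.is_floating_point(x) else x
data_args = tree_map(wrap_fn, data_gen(device="cpu"))

# Search for an offload strategy under a 16 GB budget with the asynchronous solver;
# the returned module offloads and prefetches parameters according to the chosen strategy.
model = memory_optimize(model, data_args, 16 * 1024 ** 3, "asyn")
optim = AMPOptimizer(HybridAdam(model.model.parameters(), lr=1e-3), model)

# One training step on dummy data.
data_args = tree_map(wrap_fn, data_gen(device="cuda"))
label = torch.randint(low=0, high=128, size=(64, 8), device=get_current_device())

optim.zero_grad()
loss = criterion(model(**data_args), label)
optim.backward(loss)
optim.step()
```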