mirror of https://github.com/hpcaitech/ColossalAI
* add auto-offload feature
* polish code
* fix syn offload runtime pass bug
* add offload example
* fix offload testing bug
* fix example testing bug
pull/3190/head^2
Authored by Zihao, committed by GitHub, 2 years ago.
18 changed files with 2833 additions and 0 deletions.
@@ -0,0 +1,177 @@
from typing import Dict, Tuple
from enum import Enum
import torch
from torch.optim import Optimizer

from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.utils import get_current_device

from .base_offload_module import BaseOffloadModule
from .region_manager import RegionManager
from .region import Region


class OptimState(Enum):
    SCALED = 0
    UNSCALED = 1


class AMPOptimizer(ColossalaiOptimizer):
    """
    A wrapper for Optimizer.
    Code reference: https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/optimizer/zero_optimizer.py

    Args:
        optimizer (Optimizer): An Optimizer instance.
        module (BaseOffloadModule): A ``BaseOffloadModule`` instance.
        initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**16.
        growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
        backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
        growth_interval (int, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
        hysteresis (int, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
        min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
        max_scale (float, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
        clipping_norm (float, optional): The threshold used by `clip_grad_norm`. Defaults to 0.0 (no clipping).
        norm_type (float, optional): norm_type used for `clip_grad_norm`. Only 2.0 is supported.
    """

    def __init__(self,
                 optimizer: Optimizer,
                 module: BaseOffloadModule,
                 initial_scale: float = 2**16,
                 growth_factor: float = 2,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 1000,
                 hysteresis: int = 2,
                 min_scale: float = 1,
                 max_scale: float = 2**32,
                 clipping_norm: float = 0.0,
                 norm_type: float = 2.0):

        super().__init__(optimizer)

        self.module = module
        self.optim_state = OptimState.UNSCALED
        self.clipping_flag = clipping_norm > 0.0
        self.max_norm = clipping_norm

        self.region_manager: RegionManager = self.module.region_manager
        self.param_to_range: Dict[torch.nn.Parameter, Tuple[int, int]] = dict()
        self.param_to_region: Dict[torch.nn.Parameter, Region] = dict()

        self.fp32_to_fp16_params: Dict[torch.Tensor, torch.nn.Parameter] = dict()

        if self.clipping_flag:
            assert norm_type == 2.0, "AMPOptimizer only supports L2 norm now"

        self.__init__optimizer()

        # Grad scaler
        self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
                                             min_scale=min_scale,
                                             growth_factor=growth_factor,
                                             backoff_factor=backoff_factor,
                                             growth_interval=growth_interval,
                                             hysteresis=hysteresis,
                                             max_scale=max_scale)
        self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
        self._logger = get_dist_logger()

    def _set_grad_ptr(self):
        for group in self.param_groups:
            for fake_param in group['params']:
                region = self.param_to_region[fake_param]
                begin, end = self.param_to_range[fake_param]

                fake_param.data = region.cpu_grad[begin:end]
                fake_param.grad = fake_param.data
                fake_param.data = region.fp32_data[begin:end]

    def _update_fp16_params(self):
        none_tensor = torch.empty([0])
        for group in self.param_groups:
            for fake_param in group['params']:
                assert fake_param.grad is None
                fake_param.data = none_tensor
                self.param_to_region[fake_param].cpu_grad = None

    def _check_overflow(self):
        # clear previous overflow record
        self._found_overflow.fill_(self.module.overflow_counter.item())
        return self._found_overflow.item() > 0

    def _get_combined_scale(self):
        loss_scale = 1

        if self.optim_state == OptimState.SCALED:
            loss_scale = self.loss_scale
            self.optim_state = OptimState.UNSCALED

        combined_scale = loss_scale

        if combined_scale == 1:
            return -1
        else:
            return combined_scale

    @property
    def loss_scale(self):
        return self.grad_scaler.scale.item()

    def zero_grad(self, *args, **kwargs):
        self.module.overflow_counter = torch.cuda.IntTensor([0])
        return self.optim.zero_grad(set_to_none=True)

    def step(self, *args, **kwargs):
        # Copy gradients from model params to main params.
        self._set_grad_ptr()

        found_inf = self._check_overflow()
        if found_inf:
            self.optim_state = OptimState.UNSCALED    # no need to unscale grad
            self.grad_scaler.update(found_inf)    # update gradient scaler
            self._logger.info(f'Found overflow. Skip step')
            self.zero_grad()    # reset all gradients
            self._update_fp16_params()
            return

        # get combined scale. combined scale = loss scale * clipping norm
        # so that gradient = gradient / combined scale
        combined_scale = self._get_combined_scale()
        self.grad_scaler.update(found_inf)

        ret = self.optim.step(div_scale=combined_scale, *args, **kwargs)
        self.zero_grad()
        self._update_fp16_params()
        return ret

    def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):
        raise NotImplementedError

    def backward(self, loss: torch.Tensor):
        loss = self.loss_scale * loss
        self.optim_state = OptimState.SCALED
        self.module.backward(loss)

    def __init__optimizer(self):

        for group in self.optim.param_groups:
            fake_params_list = list()

            for param in group['params']:
                region = self.region_manager.get_region(param)
                fake_param = torch.nn.Parameter(torch.empty([0]))
                self.param_to_range[fake_param] = region.param_to_range[param]
                self.param_to_region[fake_param] = region
                fake_params_list.append(fake_param)

                # Reset existing state dict key to the new main param.
                if param in self.optim.state:
                    self.optim.state[fake_param] = self.optim.state.pop(param)

            group['params'] = fake_params_list

        # Leverage state_dict() and load_state_dict() to
        # recast preexisting per-param state tensors
        self.optim.load_state_dict(self.optim.state_dict())
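

# Editor's usage sketch (not part of the original patch). It shows how AMPOptimizer is
# meant to wrap a plain torch optimizer around an offloaded module; `offload_model` is
# assumed to come from `memory_optimize`, and the batch/loss handling is a placeholder
# that assumes the wrapped model returns a scalar loss.
def _example_amp_optimizer_usage(offload_model: BaseOffloadModule, batch: dict):
    base_optim = torch.optim.Adam(offload_model.parameters(), lr=1e-3)
    optimizer = AMPOptimizer(base_optim, offload_model, initial_scale=2**16)

    loss = offload_model(**batch)    # fp16 forward through the wrapper
    optimizer.backward(loss)         # scales the loss and runs backward on the wrapper
    optimizer.step()                 # unscales, updates, and skips the step on overflow
    return loss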
@@ -0,0 +1,109 @@
from typing import Optional, Set
from functools import partial
import torch
import torch.nn as nn

from colossalai.nn.parallel.data_parallel import _cast_float
from colossalai.gemini.tensor_utils import free_storage

from .region_manager import RegionManager
from .util import GlobalRuntimeInfo


class BaseOffloadModule:
    """
    BaseOffloadModule: A model wrapper for parameter offloading.

    Args:
        model (nn.Module): the model to apply parameter offloading to.
        region_manager (RegionManager): a ``RegionManager`` instance.
        is_sync (bool): whether gradients are offloaded synchronously on the current stream
            (True) or asynchronously on the d2h side stream (False).
    """

    def __init__(self,
                 model: nn.Module,
                 region_manager: RegionManager,
                 is_sync=True):

        self.model = model
        self.region_manager = region_manager
        self.grad_hook_list = []
        self.overflow_counter = torch.cuda.IntTensor([0])

        self.grad_offload_stream = torch.cuda.current_stream() if is_sync else GlobalRuntimeInfo.d2h_stream

        self._cast_buffers()

    def register_grad_hook(self):
        for p in self.model.parameters():
            if p.requires_grad:
                self.grad_hook_list.append(p.register_hook(partial(self.grad_handle, p)))

    def remove_grad_hook(self):
        for hook in self.grad_hook_list:
            hook.remove()

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def _pre_forward(self):
        self.register_grad_hook()
        for region in self.region_manager.region_list:
            region.cpu_grad = None

    def forward(self, *args, **kwargs):
        args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
        self.model.zero_grad(set_to_none=True)
        self._pre_forward()
        outputs = self.model(*args, **kwargs)
        return outputs

    def backward(self, loss):
        loss.backward()
        self._post_backward()

    def _post_backward(self):
        torch.cuda.synchronize()
        self.remove_grad_hook()

        for p in self.model.parameters():
            p.grad = None

        GlobalRuntimeInfo.fwd_prefetch_event_map.clear()
        GlobalRuntimeInfo.bwd_prefetch_event_map.clear()

    def grad_handle(self, p, grad):
        empty_grad = torch.empty_like(grad)
        free_storage(empty_grad)
        with torch._C.DisableTorchFunction():
            region = self.region_manager.get_region(p)
            region.copy_grad_to_region_slice(p, grad)
            if region.can_release:
                self.overflow_counter += region.has_inf_or_nan
                master_stream = torch.cuda.current_stream()
                with torch.cuda.stream(self.grad_offload_stream):
                    GlobalRuntimeInfo.d2h_stream.wait_stream(master_stream)
                    region.move_grad_to_cpu()
        return empty_grad

    def _cast_buffers(self):
        for buffer in self.model.buffers():
            buffer.data = buffer.cuda()

    def parameters(self, recurse: bool = True):
        return self.model.parameters(recurse)

    def named_parameters(self, prefix: str = '', recurse: bool = True):
        return self.model.named_parameters(prefix, recurse)

    def named_buffers(self, prefix: str = '', recurse: bool = True):
        return self.model.named_buffers(prefix, recurse)

    def named_children(self):
        return self.model.named_children()

    def named_modules(self,
                      memo: Optional[Set[torch.nn.Module]] = None,
                      prefix: str = '',
                      remove_duplicate: bool = True):
        return self.model.named_modules(memo, prefix, remove_duplicate)
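

# Editor's note (sketch, not part of the original patch): the wrapper, not the inner
# `self.model`, must be called each iteration so that inputs are cast to fp16 and the
# per-parameter grad hooks are (re)registered before autograd runs. The loss below is a
# placeholder.
def _example_offload_module_iteration(wrapper: BaseOffloadModule, data: torch.Tensor):
    out = wrapper(data)       # _pre_forward(): hooks registered, region.cpu_grad cleared
    loss = out.sum()          # placeholder scalar loss
    wrapper.backward(loss)    # grad hooks copy grads into regions and offload them to CPU
    return loss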
@@ -0,0 +1,49 @@
from typing import Dict
import torch
import torch.fx
from torch.fx import GraphModule
from torch.utils._pytree import tree_map

from colossalai.fx import ColoTracer, is_compatible_with_meta
from colossalai.fx.passes.meta_info_prop import MetaInfoProp

from .region_manager import RegionManager
from .runtime import runtime_syn_offload_apply_pass, runtime_asyn_offload_apply_pass
from .base_offload_module import BaseOffloadModule
from .util import compute_max_param_mem, compute_total_param_mem, compute_act_peak_mem, GlobalRuntimeInfo


def memory_optimize(model: torch.nn.Module,
                    inps: Dict[str, torch.Tensor],
                    memory_budget: float = -1.0,
                    solver_name: str = 'asyn'):

    model = model.cpu().half()
    tracer = ColoTracer()
    assert is_compatible_with_meta()
    wrap_fn = lambda x: x.to("meta") if isinstance(x, torch.Tensor) else x
    meta_args = tree_map(wrap_fn, inps)
    graph = tracer.trace(model, meta_args=meta_args)
    gm = GraphModule(model, graph, model.__class__.__name__)
    interp = MetaInfoProp(gm)
    interp.propagate(*meta_args.values())

    region_manager = RegionManager(graph, solver_name=solver_name, memory_budget=memory_budget)
    region_manager._build_regions()
    GlobalRuntimeInfo.region_list = region_manager.region_list

    act_peak_mem = compute_act_peak_mem(region_manager.region_list) / 1024 ** 2
    max_param_mem = compute_max_param_mem(region_manager.region_list) / 1024 ** 2
    total_param_mem = compute_total_param_mem(region_manager.region_list) / 1024 ** 2
    print(f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | "
          f"total_param_mem={total_param_mem:.3f} MB")

    if solver_name == 'syn':
        gm = runtime_syn_offload_apply_pass(gm, region_manager.region_list)
    elif solver_name == 'asyn':
        gm = runtime_asyn_offload_apply_pass(gm, region_manager.region_list)
    else:
        raise TypeError(f"Unknown solver name {solver_name}!")

    gm.recompile()
    optimized_model = BaseOffloadModule(gm, region_manager, solver_name == 'syn')
    return optimized_model
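

# Editor's sketch of the intended entry point (not part of the original patch). `TinyNet`,
# the input shapes, and the budget are placeholders; running it for real requires a CUDA
# device plus the tracer/solver dependencies used above.
def _example_memory_optimize():
    class TinyNet(torch.nn.Module):    # stand-in for a real transformer model
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(1024, 1024)

        def forward(self, x):
            return self.proj(x).sum()

    inps = {'x': torch.randn(4, 1024)}        # keyed by forward-argument name
    budget = 2 * 1024 ** 3                    # illustrative 2 GB budget in bytes
    wrapper = memory_optimize(TinyNet(), inps, memory_budget=budget, solver_name='asyn')
    return wrapper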
@@ -0,0 +1,144 @@
from typing import List, Dict, Tuple
import torch
from torch.fx import Node
from colossalai.gemini.tensor_utils import alloc_storage, free_storage


class Region:
    """
    Region: A container owning a piece of contiguous nodes in the DNN computing graph.

    Args:
        r_id (int): the index of the region in the computing graph.
    """

    def __init__(self, r_id: int = 0) -> None:
        self.r_id: int = r_id
        self.fp16_params: List[torch.nn.Parameter] = []
        self.param_size: int = 0
        self.shared_rid: int = self.r_id

        self.param_num: int = 0
        self.grad_num: int = 0
        self.fp16_data = None
        self.fp32_data = None
        self.cpu_grad = None
        self.temp_fp32_data = None
        self.param_to_range: Dict[torch.nn.Parameter, Tuple[int, int]] = dict()

        self.need_offload: bool = False
        self.is_syn: bool = False
        self.nodes: List[Node] = []
        self.fwd_prefetch_region = None
        self.bwd_prefetch_region = None

        self.in_mem_pool_flag: bool = False

    @property
    def can_release(self) -> bool:
        """
        Check if the region can be released.
        """
        return self.grad_num == self.param_num

    @property
    def has_inf_or_nan(self) -> bool:
        """
        Check if the grad of the region has inf or nan values on CUDA.
        """
        return torch.isinf(self.fp16_data).any() | torch.isnan(self.fp16_data).any()

    def init_param_data(self, pre_alloc_tensor: torch.Tensor = None):
        """
        Map the parameters in the region to a contiguous memory space.
        """

        self.fp16_data = torch.zeros(self.param_num, dtype=torch.half, device='cuda')
        offset = 0
        for param in self.fp16_params:
            param.data = param.data.cuda()
            p_num = param.data.numel()
            self.fp16_data[offset:offset + p_num].copy_(param.data.flatten())
            param.data = self.fp16_data[offset:offset + p_num].view(param.data.shape)
            self.param_to_range[param] = (offset, offset + p_num)
            offset += p_num

        self.fp32_data = self.fp16_data.float().cpu().pin_memory()
        free_storage(self.fp16_data)
        if self.in_mem_pool_flag and pre_alloc_tensor is not None:
            self.fp16_data = pre_alloc_tensor

    def move_param_to_cuda(self):
        """
        Move parameters from CPU to GPU.
        It first moves the float32 parameters to the GPU and then converts them
        to half precision on the GPU, because precision conversion on the CPU
        is much slower than the extra data transfer it would avoid.
        """

        self.temp_fp32_data.copy_(self.fp32_data, non_blocking=True)
        self.temp_fp32_data.record_stream(torch.cuda.current_stream())
        if not self.in_mem_pool_flag:
            alloc_storage(self.fp16_data)
        self.fp16_data[:self.param_num].copy_(self.temp_fp32_data)
        self.fp16_data.record_stream(torch.cuda.current_stream())

        self.__update_params_ptr()

    def move_grad_to_cpu(self):
        """
        Move gradients from GPU to CPU.
        """

        self.cpu_grad = torch.empty(self.param_num, dtype=torch.half, pin_memory=True)
        self.cpu_grad.copy_(self.fp16_data[:self.param_num], non_blocking=True)
        self.fp16_data.record_stream(torch.cuda.current_stream())
        if not self.in_mem_pool_flag:
            self.free_cuda_data()

        self.grad_num = 0

    def free_cuda_data(self):
        free_storage(self.fp16_data)

        # torch.cuda.empty_cache()

    def copy_grad_to_region_slice(self, param: torch.nn.Parameter, data_slice: torch.Tensor) -> None:
        """
        Copy the gradient slice into the region's contiguous memory at the range owned by the parameter.

        Args:
            param (torch.nn.Parameter): the param used to retrieve meta information
            data_slice (torch.Tensor): the tensor to be copied to the region
        """

        begin, end = self.param_to_range[param]
        self.fp16_data[begin:end].copy_(data_slice.data.flatten())
        param.data = self.fp16_data[begin:end].view(param.data.shape)

        self.grad_num += data_slice.numel()

    def split(self, cut_node_idx: int, cut_param_idx: int):
        """
        Split the region into two and return the latter.
        """
        new_reg = Region(r_id=self.r_id + 1)
        new_reg.nodes = self.nodes[cut_node_idx:]
        new_reg.fp16_params = self.fp16_params[cut_param_idx:]
        for p in new_reg.fp16_params:
            new_reg.param_size += p.data.numel() * p.data.element_size()
            new_reg.param_num += p.data.numel()

        self.nodes = self.nodes[:cut_node_idx]
        self.fp16_params = self.fp16_params[:cut_param_idx]
        self.param_size -= new_reg.param_size
        self.param_num -= new_reg.param_num

        return new_reg

    def __update_params_ptr(self) -> None:
        for param in self.fp16_params:
            begin, end = self.param_to_range[param]
            param.data = self.fp16_data[begin:end].view(param.data.shape)
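

# Editor's illustration (not part of the original patch) of the flat-buffer trick that
# init_param_data relies on: parameters are copied into one contiguous fp16 tensor and
# re-viewed as slices of it, so a whole region can be moved or freed with a single
# storage operation. Plain torch, runs on CPU; the parameter shapes are arbitrary.
def _example_flat_packing():
    params = [torch.nn.Parameter(torch.randn(4, 4)), torch.nn.Parameter(torch.randn(8))]
    total = sum(p.numel() for p in params)
    flat = torch.zeros(total, dtype=torch.half)
    offset = 0
    for p in params:
        n = p.numel()
        flat[offset:offset + n].copy_(p.data.flatten())
        p.data = flat[offset:offset + n].view(p.shape)    # p now aliases the flat buffer
        offset += n
    return flat, params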
@@ -0,0 +1,526 @@
from typing import List, Any, Dict, Tuple
import torch
from torch.fx import Graph, Node

from .solver import SolverFactory
from .training_simulator import TrainingSimulator
from .region import Region
from .util import NodeInfo


class RegionManager:
    """
    RegionManager is used to construct and manage the offload plan for the model execution.

    Args:
        graph (Graph): a Graph object used for analysis and strategy generation.
        solver_name (str): a solver name which specifies the preferences for plan searching.
        memory_budget (float): the given memory budget.
        cnode (List[str], optional): the list of common nodes; it should be a subset of the model inputs.
    """

    def __init__(self,
                 graph: Graph,
                 solver_name: str = 'asyn',
                 memory_budget: float = -1.0,
                 cnode: List[str] = None):

        self.graph = graph
        assert graph.owning_module is not None, 'The given graph is not associated with an owning_module'
        self.root_module = self.graph.owning_module
        self.nodes = list(graph.nodes)
        self.cnode = cnode
        self.only_param_ops = []
        self.param_region_map: Dict[torch.nn.Parameter, Region] = dict()
        self.shared_region_pairs: List[Tuple[Region, Region]] = list()
        self.region_list: List[Region] = list()
        self.rid_in_pool: List[int] = list()
        self.mem_block_size: int = 0
        self.memory_budget = memory_budget

        self.solver_name = solver_name
        self.require_pool: bool = solver_name == 'asyn'

        self.reg_to_block: Dict[int, int] = dict()
    def _build_regions(self):
        """
        1. Pre-processing: linearize the computing graph and
           merge smaller regions into larger ones.
        2. Construct a solver to search for an efficient offload strategy.
        3. Post-processing: place regions early when using asynchronous mode,
           and initialize region data.
        """

        self._pre_process()

        solver_cls = SolverFactory.create(self.solver_name)
        solver = solver_cls(self.region_list, self.memory_budget)
        solver._call_solver()

        self._post_process(solver.best_ts)
    def _pre_process(self):

        init_region_list = self._linearize_graph()

        if len(self.shared_region_pairs) > 1:
            raise NotImplementedError(
                'The current version only considers at most one pair of parameter sharing.')

        elif len(self.shared_region_pairs) == 1:
            shared_regs = self.shared_region_pairs[0]
            assert shared_regs[0].shared_rid == shared_regs[1].r_id \
                and shared_regs[1].shared_rid == shared_regs[0].r_id
            fst_id = shared_regs[0].r_id
            lst_id = shared_regs[1].r_id
            regs_left_out = init_region_list[:fst_id + 1]
            regs_right_out = init_region_list[lst_id:]
            hold_regs = init_region_list[fst_id + 1:lst_id]
        else:
            regs_left_out = []
            regs_right_out = []
            hold_regs = init_region_list

        self.mem_block_size = self._search_block_size(hold_regs)
        hold_regs = self._merge_small_regions(hold_regs)

        if self.require_pool:
            for reg in hold_regs:
                reg.in_mem_pool_flag = True
                self.rid_in_pool.append(reg.r_id)

        self.region_list.extend(regs_left_out)
        self.region_list.extend(hold_regs)

        for reg in regs_right_out:
            reg.r_id = self.region_list[-1].r_id + 1
            self.region_list[reg.shared_rid].shared_rid = reg.r_id
            self.region_list.append(reg)

        self._process_shared_region()

        self.max_param_num = max([reg.param_num for reg in self.region_list])
        self.memory_budget -= self.max_param_num * torch.tensor([], dtype=torch.float32).element_size()

    def _post_process(self, ts: TrainingSimulator = None):
        if self.require_pool:
            self._early_region_placement(ts)
        self._init_region_data()
    def _early_region_placement(self, ts: TrainingSimulator):
        """
        Implements the early region placement strategy to avoid GPU memory fragmentation.
        It maps all region data into a contiguous memory space and
        reuses the same memory space for regions that do not coexist.

        Args:
            ts (TrainingSimulator): the best training simulator, which records the region execution flow.

        Raises:
            NotImplementedError: due to the naive implementation,
                it may not find a suitable region placement strategy for the given execution flow.
        """

        reg_flow = torch.cat([ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0)
        mem_block_num = torch.max(torch.sum(reg_flow[:, self.rid_in_pool], dim=1))
        coexist_matrix = torch.logical_or(ts.fwd_reg_flow, ts.bwd_reg_flow)

        block_to_regs = {}
        for block_idx in range(mem_block_num):
            block_to_regs[block_idx] = []
        for reg in self.region_list:
            if reg.r_id in self.rid_in_pool:
                cur_reg_appears = coexist_matrix[:, reg.r_id]
                cur_reg_coexists = torch.sum(coexist_matrix[cur_reg_appears], dim=0).bool()
                for block_idx in range(mem_block_num):
                    if not any(cur_reg_coexists[block_to_regs[block_idx]]):
                        block_to_regs[block_idx].append(reg.r_id)
                        self.reg_to_block[reg.r_id] = block_idx
                        break

                if reg.r_id not in self.reg_to_block:
                    raise NotImplementedError(
                        f'can not find a block from the memory pool to store parameters of the region')
        self.memory_pool = torch.chunk(torch.zeros(int(mem_block_num * self.mem_block_size / 2),
                                                   dtype=torch.half, device='cuda'),
                                       chunks=int(mem_block_num))

    def _merge_small_regions(self, orig_reg_list: List[Region]) -> List[Region]:
        """
        Merge smaller regions into larger ones for better bandwidth utilization and easier management.
        It is inspired by Gemini.

        Args:
            orig_reg_list (List[Region]): original region list.

        Returns:
            List[Region]: region list after merging.
        """

        r_id = orig_reg_list[0].r_id
        region = Region(r_id=r_id)
        region_list = [region]

        for orig_reg in orig_reg_list:
            if region_list[-1].param_size + orig_reg.param_size > self.mem_block_size:
                r_id += 1
                region = Region(r_id=r_id)
                region_list.append(region)
            region.param_size += orig_reg.param_size
            region.param_num += orig_reg.param_num
            region.nodes.extend(orig_reg.nodes)
            region.fp16_params.extend(orig_reg.fp16_params)
            self.__update_param_region_map(orig_reg.fp16_params, region)

        return region_list

    def _search_block_size(self,
                           region_list: List[Region],
                           search_interval_byte: int = 1024,
                           search_range_byte: int = 128 * 1024 ** 2) -> int:
        """
        Search for a suitable memory block size.

        Args:
            region_list (List[Region]): region list.
            search_interval_byte (int): searching interval in bytes.
            search_range_byte (int): searching range in bytes.

        Returns:
            int: the best memory block size.
        """

        def _get_wasted_mem(size_list: List[int], blk_size: int):
            """
            Get the wasted bytes for a certain block size.
            """
            acc_wasted = 0
            left = 0
            for s in size_list:
                if left + s > blk_size:
                    acc_wasted += blk_size - left
                    left = s
                left += s
            acc_wasted += blk_size - left
            return acc_wasted

        param_size_list = [
            region.param_size for region in region_list if region.r_id == region.shared_rid]

        start_size = max(param_size_list)
        min_mem_waste = float('+inf')
        best_block_size = start_size

        for block_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
            temp_waste = 0
            temp_waste += _get_wasted_mem(param_size_list, block_size)
            if temp_waste < min_mem_waste:
                min_mem_waste = temp_waste
                best_block_size = block_size

        return best_block_size

    def _init_region_data(self):
        """
        Initialize region data, which maps the parameters in the region to a contiguous memory space.
        """

        self.temp_fp32_data = torch.zeros(self.max_param_num, device='cuda', dtype=torch.float32)

        for region in self.region_list:
            pre_alloc_tensor = None
            if self.require_pool and region.r_id in self.rid_in_pool:
                block_idx = self.reg_to_block[region.r_id]
                pre_alloc_tensor = self.memory_pool[block_idx]

            if region.r_id <= region.shared_rid:
                region.init_param_data(pre_alloc_tensor)
            else:
                shared_region = self.region_list[region.shared_rid]
                region.fp16_data = shared_region.fp16_data
                region.fp32_data = shared_region.fp32_data
                region.param_to_range = shared_region.param_to_range
            region.temp_fp32_data = self.temp_fp32_data[:region.param_num].detach()

        torch.cuda.empty_cache()
    def _process_shared_region(self):
        """
        Special processing for the shared region, using the GPT2 and BERT cases
        (tied word embeddings) as prior knowledge.
        """

        if len(self.shared_region_pairs):
            assert len(self.shared_region_pairs) <= 1
            former_reg, latter_reg = self.shared_region_pairs[0]
            assert latter_reg.param_num >= former_reg.param_num
            embedding_node = former_reg.nodes[-1]
            assert embedding_node.op == 'call_module' and isinstance(
                self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding)
            if latter_reg.param_num > former_reg.param_num:
                for idx, n in enumerate(latter_reg.nodes):
                    if (n.op == 'call_module' and isinstance(self.root_module.get_submodule(n.target),
                                                             torch.nn.Linear)) or \
                            (n.op == 'call_function' and n.target is torch.nn.functional.linear):
                        cut_node_idx = idx + 1
                        break
                assert len(latter_reg.fp16_params) == 2
                new_reg = latter_reg.split(cut_node_idx, 1)
                for p in new_reg.fp16_params:
                    self.param_region_map[p] = new_reg
                self.region_list.insert(new_reg.r_id, new_reg)
                for reg in self.region_list[new_reg.r_id + 1:]:
                    reg.r_id += 1
            latter_reg.shared_rid = former_reg.r_id
            former_reg.shared_rid = latter_reg.r_id
    def _linearize_graph(self) -> List[Region]:
        """Linearize the computing graph (self.graph) into a list of regions.

        Returns:
            List[Region]: each region holds a consecutive slice of the linearized node list.

        Remarks:
            Inplace ops and shape-consistency ops are merged into the previous node.
        """

        # List of target names that could be seen as common nodes
        common_ops = ["getattr", "getitem", "size"]

        def _is_cop(target: Any) -> bool:
            """Check if an op could be seen as a common node

            Args:
                target (Any): node target

            Returns:
                bool
            """

            if isinstance(target, str):
                return target in common_ops
            else:
                return target.__name__ in common_ops

        def _is_act(data: Any) -> bool:
            """Check if the meta data contains activation tensors

            Args:
                data (Any): meta data of a node

            Returns:
                bool
            """

            label = False
            if isinstance(data, torch.Tensor):
                return True
            elif isinstance(data, (tuple, list)):
                for d in data:
                    label = label or _is_act(d)
            return label

        def _maybe_param_comp_start() -> bool:
            """Check if an op could be seen as parameter computation start

            Args:
                n (Node): node

            Returns:
                bool
            """

            label = False
            if n.op == "get_attr":
                label = True
            elif n.op == "call_module":
                target = n.target
                submod = self.root_module.get_submodule(target)
                if (
                        len(list(submod.named_parameters(recurse=False))) != 0
                        or len(list(submod.named_buffers(recurse=False))) != 0
                ):
                    label = True

            return label and not sum([v for _, v in param_op_deps.items()])

        def _is_param_comp_end() -> bool:
            """Check if an op could be seen as parameter computation end

            Args:
                n (Node): node

            Returns:
                bool
            """

            def _is_inplace(n: Node):
                """Get the inplace argument from ``torch.fx.Node``
                """
                inplace = False
                if n.op == "call_function":
                    inplace = n.kwargs.get("inplace", False)
                elif n.op == "call_module":
                    inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
                return inplace

            label = False

            if n.op == "call_module":
                target = n.target
                submod = self.root_module.get_submodule(target)
                if (
                        len(list(submod.named_parameters(recurse=False))) != 0
                        or len(list(submod.named_buffers(recurse=False))) != 0
                ):
                    label = True

            elif n.op == "call_function":
                label = any(map(lambda x: x.name in self.only_param_ops, n.all_input_nodes)) and any(
                    map(lambda x: x.name not in self.only_param_ops and not _is_cop(n.target), n.all_input_nodes))

            return label and not sum([v for _, v in param_op_deps.items()]) and not any(map(_is_inplace, n.users))

        def _exception_node_handling():
            # TODO meta info prop bug
            if n.name.__contains__("transpose") and n.meta['fwd_out'][0].dim() <= 2:
                n.meta['fwd_out'] = []

        # make sure that each item in cnode is valid
        if self.cnode:
            for name in self.cnode:
                try:
                    assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \
                        f"Common node {name} is not an input of the model."
                except StopIteration:
                    raise ValueError(f"Common node name {name} not in graph.")
        else:
            self.cnode = []

        node_id = 0
        region_id = 0

        param_op_deps = {}

        deps = {}
        region_list = []
        region = Region(r_id=region_id)

        act_n = None

        for n in self.graph.nodes:
            if n.op != "placeholder" and n.op != "output":
                for n_par in n.all_input_nodes:
                    if n_par.op != "placeholder" and n_par.name not in self.cnode:
                        deps[n_par] -= 1
                    if n_par.op != "placeholder" and n_par.name in self.only_param_ops:
                        param_op_deps[n_par] -= 1

                if act_n in region.nodes and _maybe_param_comp_start():
                    ns = []
                    border_n_idx = region.nodes.index(act_n)
                    if border_n_idx < len(region.nodes):
                        ns = region.nodes[border_n_idx + 1:]
                        region.nodes = region.nodes[:border_n_idx + 1]
                    region_list.append(region)
                    region_id += 1
                    region = Region(r_id=region_id)
                    region.nodes = ns

                _exception_node_handling()
                region.nodes.append(n)
                self._set_node_and_region_info(node_id, n, region)
                node_id += 1

                # if the node could free all dependencies in graph
                # we could begin a new region
                if _is_param_comp_end():
                    region_list.append(region)
                    region_id += 1
                    region = Region(r_id=region_id)

                # propagate common node attr if possible
                if len(n.all_input_nodes) == len(
                        [node for node in n.all_input_nodes if node.name in self.cnode]) or _is_cop(n.target):
                    self.cnode.append(n.name)
                else:
                    deps[n] = len([user for user in n.users if user.op != "output"])

                # propagate param node attr if possible
                if len(n.all_input_nodes) == len(
                        [node for node in n.all_input_nodes if node.name in self.only_param_ops]) or n.op == "get_attr":
                    self.only_param_ops.append(n.name)
                    param_op_deps[n] = len([user for user in n.users if user.op != "output"])

                # record last activation node
                if _is_act(n._meta_data):
                    act_n = n

        if len(region.nodes):
            region_list.append(region)

        return region_list

    def _set_node_and_region_info(self, node_id: int, cur_n: Node, cur_reg: Region):

        cur_n.node_info = NodeInfo(node_id)

        if cur_n.op == 'call_module':
            target = cur_n.target
            submod = self.root_module.get_submodule(target)
            for p in list(submod.parameters(recurse=False)):

                if p in self.param_region_map:
                    cur_reg.shared_rid = self.param_region_map[p].r_id
                    self.param_region_map[p].shared_rid = cur_reg.r_id
                    self.shared_region_pairs.append((self.param_region_map[p], cur_reg))
                else:
                    self.param_region_map[p] = cur_reg

                cur_reg.fp16_params.append(p)
                cur_reg.param_num += p.data.numel()
                cur_reg.param_size += p.data.numel() * p.data.element_size()

        elif cur_n.op == "get_attr":
            attr_itr = self.root_module
            atoms = cur_n.target.split(".")
            for atom in atoms:
                attr_itr = getattr(attr_itr, atom)

            if isinstance(attr_itr, torch.nn.Parameter):

                if attr_itr in self.param_region_map:
                    cur_reg.shared_rid = self.param_region_map[attr_itr].r_id
                    self.param_region_map[attr_itr].shared_rid = cur_reg.r_id
                    self.shared_region_pairs.append((self.param_region_map[attr_itr], cur_reg))
                else:
                    self.param_region_map[attr_itr] = cur_reg

                cur_reg.fp16_params.append(attr_itr)
                cur_reg.param_num += attr_itr.data.numel()
                cur_reg.param_size += attr_itr.data.numel() * attr_itr.data.element_size()

    def get_region(self, param: torch.nn.Parameter) -> Region:
        """
        Return the region owning the parameter.

        Args:
            param (torch.nn.Parameter): a torch parameter object
        """
        return self.param_region_map[param]

    def __update_param_region_map(self, params: List[torch.nn.Parameter], region: Region):
        for p in params:
            self.param_region_map[p] = region
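

# Editor's illustration (not part of the original patch): the kind of torch.fx node
# stream that _linearize_graph walks over. For a tiny module, the `call_module` /
# `get_attr` nodes are the parameter-owning points that seed new regions; the module
# below is a placeholder.
def _example_fx_node_stream():
    import torch.fx as fx

    class Tiny(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(8, 8)
            self.fc2 = torch.nn.Linear(8, 8)

        def forward(self, x):
            return self.fc2(torch.relu(self.fc1(x)))

    graph = fx.symbolic_trace(Tiny()).graph
    return [(n.op, n.name) for n in graph.nodes]
    # e.g. [('placeholder', 'x'), ('call_module', 'fc1'), ('call_function', 'relu'),
    #       ('call_module', 'fc2'), ('output', 'output')]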
@@ -0,0 +1,253 @@
from typing import List
import torch
from torch.fx.node import Node

from .region import Region
from .util import GlobalRuntimeInfo, requires_upload_p_in_fwd


class SynPreFwdPostBwdOP(torch.autograd.Function):
    """
    A customized prefetch and offload operation.

    Args:
        input_: input tensor.
        fwd_info: information dict, which contains region indices
            that need to be uploaded or freed during forward pass.
        bwd_info: information dict, which contains region indices
            that need to be uploaded during backward pass.
    """

    @staticmethod
    def forward(ctx, input_, fwd_info, bwd_info):
        ctx.bwd_info = bwd_info
        d2h_rid = fwd_info.get('d2h_rid', None)
        if d2h_rid is not None:
            free_region = GlobalRuntimeInfo.region_list[d2h_rid]
            assert isinstance(free_region, Region)
            free_region.free_cuda_data()

        h2d_rid = fwd_info.get('h2d_rid', None)
        if h2d_rid is not None:
            h2d_region = GlobalRuntimeInfo.region_list[h2d_rid]
            assert isinstance(h2d_region, Region)
            h2d_region.move_param_to_cuda()

        return input_

    @staticmethod
    def backward(ctx, grad_output):

        h2d_rid = ctx.bwd_info.get('h2d_rid', None)
        if h2d_rid is not None:
            pref_region = GlobalRuntimeInfo.region_list[h2d_rid]
            assert isinstance(pref_region, Region)
            pref_region.move_param_to_cuda()

        return grad_output, None, None


class AsynPreFwdPostBwdOP(torch.autograd.Function):
    """
    A customized prefetch and offload operation.

    Args:
        input_: input tensor.
        fwd_info: information dict, which contains region indices
            that need to be prefetched, waited, or freed during forward pass.
        bwd_info: information dict, which contains region indices
            that need to be prefetched or waited during backward pass.
    """

    @staticmethod
    def forward(ctx, input_, fwd_info, bwd_info):
        ctx.bwd_info = bwd_info

        sync_rid = fwd_info.get('sync_rid', None)
        if sync_rid is not None:
            prefetch_event = GlobalRuntimeInfo.fwd_prefetch_event_map.get(sync_rid, None)
            if prefetch_event:
                prefetch_event.wait()

        h2d_rid = fwd_info.get('h2d_rid', None)
        if h2d_rid is not None:
            pref_region = GlobalRuntimeInfo.region_list[h2d_rid]
            assert isinstance(pref_region, Region)
            master_stream = torch.cuda.current_stream()
            with torch.cuda.stream(GlobalRuntimeInfo.h2d_stream):
                GlobalRuntimeInfo.h2d_stream.wait_stream(master_stream)
                pref_region.move_param_to_cuda()

            prefetch_event = torch.cuda.Event()
            prefetch_event.record(GlobalRuntimeInfo.h2d_stream)
            GlobalRuntimeInfo.fwd_prefetch_event_map[h2d_rid] = prefetch_event

        return input_

    @staticmethod
    def backward(ctx, grad_output):

        sync_rid = ctx.bwd_info.get('sync_rid', None)
        if sync_rid is not None:
            wait_region = GlobalRuntimeInfo.region_list[sync_rid]
            assert isinstance(wait_region, Region)
            prefetch_event = GlobalRuntimeInfo.bwd_prefetch_event_map.get(sync_rid, None)
            if prefetch_event:
                prefetch_event.wait()
            else:
                wait_region.move_param_to_cuda()

        h2d_rid = ctx.bwd_info.get('h2d_rid', None)
        if h2d_rid is not None:
            pref_region = GlobalRuntimeInfo.region_list[h2d_rid]
            assert isinstance(pref_region, Region)
            master_stream = torch.cuda.current_stream()
            with torch.cuda.stream(GlobalRuntimeInfo.h2d_stream):
                GlobalRuntimeInfo.h2d_stream.wait_stream(master_stream)
                pref_region.move_param_to_cuda()

            prefetch_event = torch.cuda.Event()
            prefetch_event.record(GlobalRuntimeInfo.h2d_stream)
            GlobalRuntimeInfo.bwd_prefetch_event_map[h2d_rid] = prefetch_event
        return grad_output, None, None

def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info):
    '''
    Convert the upload and offload operations into a runtime action.

    Args:
        tensor (torch.Tensor): input tensor.
        fwd_info (dict): information dict, which contains region indices
            that need to be uploaded or freed during forward pass.
        bwd_info (dict): information dict, which contains region indices
            that need to be uploaded during backward pass.
    '''
    with torch._C.DisableTorchFunction():
        ret = SynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
    return ret


def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info):
    '''
    Convert the prefetch and offload operations into a runtime action.

    Args:
        tensor (torch.Tensor): input tensor.
        fwd_info (dict): information dict, which contains region indices
            that need to be prefetched, waited, or freed during forward pass.
        bwd_info (dict): information dict, which contains region indices
            that need to be prefetched or waited during backward pass.
    '''
    with torch._C.DisableTorchFunction():
        ret = AsynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
    return ret


def replace_node_users(orig_node: Node, inserted_node: Node, rep_user_nodes: List[Node] = None):
    user_list = list(orig_node.users.keys())
    if rep_user_nodes is not None:
        user_list = rep_user_nodes
    for user in user_list:
        if user == inserted_node:
            continue
        new_args = list(user.args)
        new_kwargs = dict(user.kwargs)
        # the origin node may be a positional argument or keyword argument of the user node
        if orig_node in new_args:
            # substitute the origin node with offload_apply_node
            new_args[new_args.index(orig_node)] = inserted_node
            user.args = tuple(new_args)
        elif str(orig_node) in new_kwargs:
            # substitute the origin node with offload_apply_node
            new_kwargs[str(orig_node)] = inserted_node
            user.kwargs = new_kwargs


def runtime_syn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[Region]):
    """
    This pass adds the synchronous upload and offload apply nodes to the original graph.
    """
    mod_graph = gm.graph
    last_inp_node = tuple(mod_graph.nodes)[0]

    for r_idx, region in enumerate(region_list):
        # forward upload
        fwd_info = {}
        if requires_upload_p_in_fwd(region_list[region.shared_rid]):
            fwd_info['h2d_rid'] = region.r_id

        # forward offload
        if r_idx > 0 and region_list[r_idx - 1].need_offload:
            fwd_info['d2h_rid'] = r_idx - 1

        bwd_info = {}
        # backward upload
        if r_idx > 0 and region_list[r_idx - 1].need_offload:
            bwd_info['h2d_rid'] = region_list[r_idx - 1].r_id

        if fwd_info or bwd_info:
            with mod_graph.inserting_after(last_inp_node):
                new_node = mod_graph.create_node('call_function', convert_fwd_upload_bwd_offload_to_action,
                                                 args=(last_inp_node, fwd_info, bwd_info))
            replace_node_users(last_inp_node, new_node)

        last_inp_node = region.nodes[-1]

    return gm


def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[Region]):
    """
    This pass adds the asynchronous prefetch and offload apply nodes to the original graph.
    """
    mod_graph = gm.graph

    # upload parameters of the first region
    last_inp_node = tuple(mod_graph.nodes)[0]
    first_region_with_p = [region for region in region_list if region.param_size][0]
    fwd_info = {"h2d_rid": first_region_with_p.r_id}
    with mod_graph.inserting_after(last_inp_node):
        upload_apply_node = mod_graph.create_node('call_function', convert_fwd_upload_bwd_offload_to_action,
                                                  args=(last_inp_node, fwd_info, {}))
    replace_node_users(last_inp_node, upload_apply_node)
    last_inp_node = upload_apply_node

    for r_idx, region in enumerate(region_list):
        # forward prefetch
        fwd_info = {}
        if region.param_size:
            fwd_info['sync_rid'] = region.r_id
        fwd_prefetch_region = region.fwd_prefetch_region
        if fwd_prefetch_region and requires_upload_p_in_fwd(region_list[fwd_prefetch_region.shared_rid]):
            fwd_info['h2d_rid'] = fwd_prefetch_region.r_id

        # forward offload
        if r_idx > 0 and region_list[r_idx - 1].need_offload:
            fwd_info['d2h_rid'] = r_idx - 1

        bwd_info = {}
        # backward prefetch
        if r_idx > 0 and region_list[r_idx - 1].need_offload:
            bwd_info['sync_rid'] = r_idx - 1
        if r_idx > 0 and region_list[r_idx - 1].bwd_prefetch_region:
            bwd_info['h2d_rid'] = region_list[r_idx - 1].bwd_prefetch_region.r_id

        if fwd_info or bwd_info:
            with mod_graph.inserting_after(last_inp_node):
                new_node = mod_graph.create_node('call_function', convert_fwd_prefetch_bwd_offload_to_action,
                                                 args=(last_inp_node, fwd_info, bwd_info))
            replace_node_users(last_inp_node, new_node)

        last_inp_node = region.nodes[-1]

    if region.bwd_prefetch_region:
        bwd_info = {'h2d_rid': region.bwd_prefetch_region.r_id}
        with mod_graph.inserting_after(last_inp_node):
            new_node = mod_graph.create_node('call_function', convert_fwd_prefetch_bwd_offload_to_action,
                                             args=(last_inp_node, {}, bwd_info))
        replace_node_users(last_inp_node, new_node)
    # gm.graph.print_tabular()
    return gm
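

# Editor's illustration (not part of the original patch): the graph surgery these passes
# perform, reduced to plain torch.fx. A call_function node is inserted after an anchor
# node and the anchor's users are rewired to consume it, mirroring how the passes above
# splice the convert_* actions into the graph. The tiny model is a placeholder.
def _example_insert_action_after_input():
    import torch.fx as fx

    def _identity_action(t):    # stand-in for the convert_* actions
        return t

    gm = fx.symbolic_trace(torch.nn.Sequential(torch.nn.Linear(4, 4)))
    graph = gm.graph
    anchor = tuple(graph.nodes)[0]    # the placeholder (model input)
    with graph.inserting_after(anchor):
        new_node = graph.create_node('call_function', _identity_action, args=(anchor,))
    replace_node_users(anchor, new_node)
    gm.recompile()
    return gm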
@@ -0,0 +1,523 @@
import time
from typing import List, Dict, Type
from abc import ABC, abstractmethod

NOT_NVML = False
try:
    from pynvml import *
except:
    NOT_NVML = True

import torch
from torch.fx.node import Node
from colossalai.utils.cuda import get_current_device

from .training_simulator import TrainingSimulator, SynTrainingSimulator, AsynTrainingSimulator
from .region import Region
from .util import NodeInfo, NvDevicePower


def benchmark_func(func, number=1, repeat=1, warmup=3):
    """
    Benchmark the data transfer cost of ``func``.
    """

    for i in range(warmup):
        func()

    costs = []

    for i in range(repeat):
        torch.cuda.synchronize()
        begin = time.time()
        for i in range(number):
            func()
        torch.cuda.synchronize()
        costs.append((time.time() - begin) / number)

    return sum(costs) / len(costs)


class Solver(ABC):
    """
    The parameter offload solver.

    Args:
        region_list (List[Region]): represents the linearized DNN computing graph.
        memory_budget (float): the given memory budget.
        error_factor (float): a safety factor used to shrink the memory budget, compensating
            for errors in the estimation of peak memory and execution time.
    """

    def __init__(self,
                 region_list: List[Region],
                 memory_budget: float = -1.0,
                 error_factor: float = 0.95) -> None:

        self.region_list = region_list

        self.error_factor: float = error_factor
        if memory_budget > 0:
            self.memory_budget = memory_budget * self.error_factor
        else:
            self.memory_budget = torch.cuda.get_device_properties(
                get_current_device()).total_memory * self.error_factor

        self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth()
        self.comp_power: float = self._extract_computing_power()

    @abstractmethod
    def _call_solver(self):
        raise NotImplementedError

    @abstractmethod
    def _try_to_offload(self, *args):
        raise NotImplementedError

    @abstractmethod
    def _eval_one_choice(self, *args):
        raise NotImplementedError

    def _compute_offload_profit(self, total_mem_saving: float, peak_mem_saving: float, extra_cost: float):
        """
        Compute the profit of an offload strategy,
        which packages the memory savings information for subsequent comparisons.

        Args:
            total_mem_saving (float): the total memory saving of the offload strategy.
            peak_mem_saving (float): the peak memory saving of the offload strategy.
            extra_cost (float): extra data transfer cost.

        Returns:
            tuple: profit information; the first term represents memory savings per unit of time.
        """

        if extra_cost == 0:
            # the data transfer overhead can be completely overlapped
            return (float('inf'), total_mem_saving, peak_mem_saving)
        return (total_mem_saving / extra_cost, total_mem_saving, peak_mem_saving)

    def _compare_profit(self, profit_a: tuple, profit_b: tuple) -> bool:
        """
        Compare the profits of two offload strategies by lexicographic (dictionary-order) comparison.

        Args:
            profit_a (tuple): the profit of one offload strategy.
            profit_b (tuple): the profit of another offload strategy.

        Returns:
            bool: whether profit_a is greater than profit_b.
        """

        for val1, val2 in zip(profit_a, profit_b):
            if val1 != val2:
                return val1 > val2
        return False

    def _update_state(self, best_ts: TrainingSimulator):
        """
        Update the solver state.
        """

        self.best_ts = best_ts
        self._update_node_mem_info(best_ts.fwd_node_mem, best_ts.bwd_node_mem)

    def _update_node_mem_info(self,
                              fwd_mem_info: Dict[Node, float],
                              bwd_mem_info: Dict[Node, float]):
        """
        Update the runtime memory information of each node.

        Args:
            fwd_mem_info (Dict[Node, float]): the runtime memory of each node in the forward pass.
            bwd_mem_info (Dict[Node, float]): the runtime memory of each node in the backward pass.
        """

        for node, mem in fwd_mem_info.items():
            assert hasattr(node, 'node_info') and isinstance(node.node_info, NodeInfo)
            node.node_info.runtime_fwd_mem = mem
        for node, mem in bwd_mem_info.items():
            assert hasattr(node, 'node_info') and isinstance(node.node_info, NodeInfo)
            node.node_info.runtime_bwd_mem = mem

    def _extract_computing_power(self):
        """
        Return the FP16 computing performance of the current NVIDIA GPU.

        Raises:
            TypeError: Unknown NVIDIA GPU device.
        """

        nvmlInit()
        handle = nvmlDeviceGetHandleByIndex(0)
        device_name = nvmlDeviceGetName(handle)
        units = 1e12

        if device_name.__contains__("RTX 3080"):
            return NvDevicePower.RTX3080_FP16 * units
        elif device_name.__contains__("RTX 3090"):
            return NvDevicePower.RTX3090_FP16 * units
        elif device_name.__contains__('V100'):
            return NvDevicePower.V100_FP16 * units
        elif device_name.__contains__("A100"):
            return NvDevicePower.A100_FP16 * units
        else:
            raise TypeError(f'Unknown NVIDIA GPU device name {device_name}')

    def _profile_bandwidth(self):
        """
        Profile the bidirectional communication bandwidth between CPU and GPU
        using data volumes ranging from 1KB to 1GB.
        """

        print('profiling bandwidth ......')
        link_to_bandwidth = {}
        links = ['h2d', 'd2h']

        for link in links:
            t_size = 1024
            size_to_bandwidth = {}

            # from 1KB to 1GB
            for i in range(21):
                if link == 'h2d':
                    src_tensor = torch.ones(int(t_size), dtype=torch.int8, pin_memory=True)
                    dst_tensor = torch.ones((int(t_size)), dtype=torch.int8, device='cuda')
                elif link == 'd2h':
                    src_tensor = torch.ones(int(t_size), dtype=torch.int8, device='cuda')
                    dst_tensor = torch.ones((int(t_size)), dtype=torch.int8, pin_memory=True)

                def func():
                    dst_tensor.copy_(src_tensor)

                size_to_bandwidth[t_size] = t_size / benchmark_func(func, number=5, repeat=3)
                print(f'size: {t_size / 1024 ** 2:.3f} MB, '
                      f'{src_tensor.device.type}-to-{dst_tensor.device.type} '
                      f'bandwidth: {size_to_bandwidth[t_size] / 1024 ** 3:.3f} GB/s')

                t_size *= 2

            link_to_bandwidth[link] = size_to_bandwidth
        return link_to_bandwidth
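

# Editor's worked example (not part of the original patch) of the profit ordering used
# above: profits are tuples compared lexicographically, and a zero transfer cost maps to
# an infinite first component, so fully overlapped candidates always win. The numbers
# are arbitrary illustrative values.
def _example_profit_ordering():
    fully_overlapped = (float('inf'), 512.0, 128.0)    # (saving per unit cost, total saving, peak saving)
    cheap_transfer = (256.0, 512.0, 64.0)
    slow_transfer = (32.0, 768.0, 64.0)
    # lexicographic comparison: the first components decide unless they tie
    ranked = sorted([cheap_transfer, slow_transfer, fully_overlapped], reverse=True)
    return ranked    # [fully_overlapped, cheap_transfer, slow_transfer]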
||||
|
||||
|
||||
class SynGreedySolver(Solver): |
||||
|
||||
def __init__(self, |
||||
region_list: List[Region], |
||||
memory_budget: float = -1.0) -> None: |
||||
super().__init__(region_list, memory_budget) |
||||
|
||||
self.best_ts: SynTrainingSimulator = None |
||||
self._init_state() |
||||
|
||||
def _init_state(self): |
||||
""" |
||||
Initialize the solver state when without offloading. |
||||
""" |
||||
|
||||
ts = SynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth) |
||||
ts.execute() |
||||
self._update_state(ts) |
||||
|
||||
def _call_solver(self): |
||||
""" |
||||
Call the solver to search an efficient parameter offloading strategy for the linearized graph. |
||||
The solver adopts greedy algorithm. |
||||
|
||||
Raises: |
||||
NotImplementedError: Unable to find a solution for the given memory budget. |
||||
""" |
||||
|
||||
print("search offloading strategy ......") |
||||
while self.best_ts.peak_mem > self.memory_budget: |
||||
offload_region = None |
||||
best_ts = None |
||||
max_profit = (0,) |
||||
|
||||
# search which region should be offloaded, |
||||
# the last region does not need to be offloaded. |
||||
for region in self.region_list[:-1]: |
||||
if region.param_size and not region.need_offload: |
||||
temp_ts, profit = self._try_to_offload(region) |
||||
if self._compare_profit(profit, max_profit): |
||||
offload_region = region |
||||
max_profit = profit |
||||
best_ts = temp_ts |
||||
|
||||
if offload_region is not None and best_ts is not None: |
||||
offload_region.need_offload = True |
||||
offload_region.is_syn = True |
||||
self._update_state(best_ts) |
||||
else: |
||||
raise NotImplementedError( |
||||
f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, " |
||||
f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!") |
||||
|
||||
def _call_solver_l2l(self): |
||||
""" |
||||
The layer-wise offload strategy. |
||||
""" |
||||
|
||||
for region in self.region_list[:-1]: |
||||
region.need_offload = True |
||||
region.is_syn = True |
||||
|
||||
def _try_to_offload(self, offload_region: Region): |
||||
|
||||
# record previous information |
||||
orig_need_offload = offload_region.need_offload |
||||
assert not orig_need_offload |
||||
offload_region.need_offload = True |
||||
|
||||
ts, profit = self._eval_one_choice(offload_region) |
||||
|
||||
# restore previous information |
||||
offload_region.need_offload = orig_need_offload |
||||
return ts, profit |
||||
|
||||
def _eval_one_choice(self, offload_region: Region): |
||||
""" |
||||
Evaluate the profit of a strategy choice. |
||||
|
||||
Args: |
||||
offload_region (Region): the offload region of current choice. |
||||
|
||||
Returns: |
||||
SynTrainingSimulator: the training simulator corresponding to the current strategy. |
||||
tuple: contains memory saving and cost information of the current strategy. |
||||
""" |
||||
|
||||
ts = SynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth) |
||||
ts.execute() |
||||
|
||||
extra_comm_cost = 2.0 * \ |
||||
ts._get_communication_overhead('h2d', offload_region.param_size) |
||||
# the shared region needs to be moved twice |
||||
if offload_region.r_id < offload_region.shared_rid: |
||||
extra_comm_cost *= 2.0 |
||||
profit = self._compute_offload_profit( |
||||
ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost) |
||||
|
||||
return ts, profit |
||||
|
||||
|
||||
class AsynGreedySolver(Solver): |
||||
|
||||
def __init__(self, |
||||
region_list: List[Region], |
||||
memory_budget: float = -1.0, |
||||
search_window_size: int = 3): |
||||
super().__init__(region_list, memory_budget) |
||||
|
||||
self.search_window_size = search_window_size |
||||
# Records the prefetch execution location of the offloaded region |
||||
self.region_to_region_map = {} |
||||
self.best_ts: AsynTrainingSimulator = None |
||||
|
||||
self._init_state() |
||||
|
||||
def _init_state(self): |
||||
""" |
||||
Initialize the solver state for the case without offloading. |
||||
""" |
||||
|
||||
ts = AsynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth) |
||||
ts.execute() |
||||
self._update_state(ts) |
||||
print("init peak memory", self.best_ts.peak_mem / 1024 ** 2, "MB") |
||||
|
||||
def _call_solver(self): |
||||
""" |
||||
Call the solver to search for an efficient parameter offloading strategy for the linearized graph. |
||||
The solver adopts a greedy algorithm. |
||||
|
||||
Raises: |
||||
NotImplementedError: Unable to find a solution for the given memory budget. |
||||
""" |
||||
|
||||
print("search for offloading strategy ......") |
||||
# Records the prefetch execution location of the offloaded region |
||||
region_to_region_map = {} |
||||
while self.best_ts.peak_mem > self.memory_budget: |
||||
region_to_offload = None |
||||
max_offload_profit = (0,) |
||||
best_offl_ts = None |
||||
|
||||
# search which region should be offloaded, |
||||
# the last region does not need to be offloaded |
||||
for region in self.region_list[:-1]: |
||||
if region.param_size and not region.need_offload: |
||||
max_prefetch_profit = (0,) |
||||
best_pref_ts = None |
||||
|
||||
# search when to prefetch the region offloaded |
||||
for host_region in self.region_list[region.r_id + 1:region.r_id + 1 + self.search_window_size]: |
||||
if host_region.bwd_prefetch_region is not None: |
||||
continue |
||||
|
||||
temp_ts, profit = self._try_to_offload( |
||||
host_region, region) |
||||
|
||||
if self._compare_profit(profit, max_prefetch_profit): |
||||
region_to_region_map[region.r_id] = host_region |
||||
max_prefetch_profit = profit |
||||
best_pref_ts = temp_ts |
||||
if profit[0] == float('inf'): |
||||
break |
||||
|
||||
if self._compare_profit(max_prefetch_profit, max_offload_profit): |
||||
region_to_offload = region |
||||
max_offload_profit = max_prefetch_profit |
||||
best_offl_ts = best_pref_ts |
||||
|
||||
if (region_to_offload is not None) and (best_offl_ts is not None): |
||||
region_to_offload.need_offload = True |
||||
if region_to_region_map[region_to_offload.r_id] == region_to_offload: |
||||
region_to_offload.is_syn = True |
||||
else: |
||||
region_to_region_map[region_to_offload.r_id].bwd_prefetch_region = region_to_offload |
||||
self.region_to_region_map[region_to_offload.r_id] = region_to_region_map[region_to_offload.r_id] |
||||
|
||||
self._update_state(best_offl_ts) |
||||
|
||||
elif len(self.region_to_region_map) > 0: |
||||
self._repair_strategy() |
||||
else: |
||||
raise NotImplementedError( |
||||
f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, " |
||||
f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!") |
||||
|
||||
region_to_region_map.clear() |
||||
|
||||
def _try_to_offload(self, host_region: Region, offload_region: Region): |
||||
""" |
||||
Attempts to offload the region and prefetch it in the backward pass. |
||||
""" |
||||
|
||||
# record previous information |
||||
orig_prefetch = host_region.bwd_prefetch_region |
||||
orig_is_syn = offload_region.is_syn |
||||
orig_need_offload = offload_region.need_offload |
||||
|
||||
if host_region == offload_region: |
||||
offload_region.is_syn = True |
||||
else: |
||||
host_region.bwd_prefetch_region = offload_region |
||||
offload_region.need_offload = True |
||||
|
||||
ts, profit = self._eval_one_choice() |
||||
|
||||
# restore previous information |
||||
host_region.bwd_prefetch_region = orig_prefetch |
||||
offload_region.is_syn = orig_is_syn |
||||
offload_region.need_offload = orig_need_offload |
||||
|
||||
return ts, profit |
||||
|
||||
def _try_convert_to_syn_upload(self, host_region: Region, offload_region: Region): |
||||
""" |
||||
Attempts to convert asynchronous prefetch into synchronous upload operations. |
||||
""" |
||||
|
||||
# record previous information |
||||
orig_prefetch = host_region.bwd_prefetch_region |
||||
orig_is_syn = offload_region.is_syn |
||||
assert orig_prefetch is not None and not orig_is_syn |
||||
|
||||
host_region.bwd_prefetch_region = None |
||||
offload_region.is_syn = True |
||||
|
||||
ts, profit = self._eval_one_choice() |
||||
|
||||
# restore previous information |
||||
host_region.bwd_prefetch_region = orig_prefetch |
||||
offload_region.is_syn = orig_is_syn |
||||
|
||||
return ts, profit |
||||
|
||||
def _repair_strategy(self): |
||||
""" |
||||
Repair offload strategy. |
||||
It attempts to convert asynchronous prefetch operations into synchronous uploads and selects the most profitable one. |
||||
The repair process continues until peak memory is reduced or no asynchronous prefetch operation remains. |
||||
""" |
||||
print("repair strategy ......") |
||||
|
||||
peak_mem_saving = 0 |
||||
while len(self.region_to_region_map) and peak_mem_saving <= 0: |
||||
|
||||
max_profit = (0,) |
||||
best_ts = None |
||||
undo_host_region = None |
||||
undo_offload_region = None |
||||
|
||||
for offload_region_id, host_region in self.region_to_region_map.items(): |
||||
offload_region = self.region_list[offload_region_id] |
||||
assert host_region.bwd_prefetch_region == offload_region |
||||
assert offload_region.need_offload |
||||
assert not offload_region.is_syn |
||||
|
||||
ts, profit = self._try_convert_to_syn_upload(host_region, |
||||
offload_region) |
||||
|
||||
if self._compare_profit(profit, max_profit): |
||||
undo_host_region = host_region |
||||
undo_offload_region = offload_region |
||||
max_profit = profit |
||||
best_ts = ts |
||||
|
||||
if best_ts is None: |
||||
raise NotImplementedError('repair error!') |
||||
|
||||
assert not undo_offload_region.is_syn |
||||
undo_offload_region.is_syn = True |
||||
undo_host_region.bwd_prefetch_region = None |
||||
|
||||
peak_mem_saving = self.best_ts.peak_mem - best_ts.peak_mem |
||||
|
||||
self._update_state(best_ts) |
||||
self.region_to_region_map.pop(undo_offload_region.r_id) |
||||
|
||||
return best_ts |
||||
|
||||
def _eval_one_choice(self): |
||||
""" |
||||
Evaluate the profit of a strategy choice. |
||||
|
||||
Returns: |
||||
AsynTrainingSimulator: the training simulator corresponding to the current strategy. |
||||
tuple: contains memory saving and cost information of the current strategy. |
||||
""" |
||||
|
||||
ts = AsynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth) |
||||
ts.execute() |
||||
|
||||
extra_comm_cost = max(ts.iter_end_time - self.best_ts.iter_end_time, 0) |
||||
profit = self._compute_offload_profit( |
||||
ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost) |
||||
|
||||
return ts, profit |
||||
|
||||
|
||||
class SolverFactory: |
||||
solvers: Dict[str, Type[Solver]] = { |
||||
'syn': SynGreedySolver, |
||||
'asyn': AsynGreedySolver |
||||
} |
||||
|
||||
@staticmethod |
||||
def create(solver_name: str) -> Type[Solver]: |
||||
if solver_name not in SolverFactory.solvers: |
||||
raise TypeError(f"Unknown parameter offload policy {solver_name}") |
||||
return SolverFactory.solvers[solver_name] |
||||
|
||||
@staticmethod |
||||
def get_solver_names(): |
||||
return tuple(SolverFactory.solvers.keys()) |
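
Taken together, the two greedy solvers and the factory above are driven much like the solver unit test later in this diff. A minimal usage sketch, assuming a `region_list` has already been produced by `RegionManager` (that preprocessing is outside this file):

```python
# Minimal sketch; `region_list` is assumed to come from RegionManager._pre_process(),
# as in the solver test further down in this diff.
from colossalai.auto_parallel.offload.solver import SolverFactory

memory_budget = 4000 * 1024 * 1024            # budget in bytes (4000 MB)
solver_cls = SolverFactory.create('asyn')     # or 'syn' for the synchronous greedy solver
solver = solver_cls(region_list, memory_budget)
solver._call_solver()                         # raises NotImplementedError if the budget is infeasible
assert solver.best_ts.peak_mem < memory_budget
```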
@ -0,0 +1,458 @@
|
||||
import bisect |
||||
from typing import List, Dict |
||||
from collections import OrderedDict |
||||
from abc import ABC, abstractmethod |
||||
|
||||
from torch.fx.node import Node |
||||
|
||||
from .region import Region |
||||
from .util import * |
||||
|
||||
|
||||
@dataclass |
||||
class ExecutionPeriod: |
||||
start_time: float = 0 |
||||
end_time: float = 0 |
||||
|
||||
|
||||
class TrainingSimulator(ABC): |
||||
""" |
||||
The Training Simulator is used to simulate the training process. |
||||
It records computation, communication, and runtime memory during forward and backward passes. |
||||
|
||||
Args: |
||||
region_list (List[Region]): represents the linearized DNN computing graph. |
||||
comp_power (float): the NVIDIA GPU FP16 computing power. |
||||
link_to_bw (Dict[str, Dict[float, float]]): communication links and the corresponding bandwidth. |
||||
""" |
||||
|
||||
def __init__(self, |
||||
region_list: List[Region], |
||||
comp_power: float, |
||||
link_to_bw: Dict[str, Dict[float, float]]) -> None: |
||||
self.region_list = region_list |
||||
self.region_num = len(region_list) |
||||
|
||||
self.runtime_mem: int = 0 |
||||
self.peak_mem: int = 0 |
||||
self.total_mem_saving: int = 0 |
||||
|
||||
self.fwd_node_mem: Dict[Node, float] = {} |
||||
self.bwd_node_mem: Dict[Node, float] = {} |
||||
|
||||
# Node dependencies in backward pass |
||||
self.bwd_node_deps: Dict[Node, int] = {} |
||||
|
||||
self.comp_power: float = comp_power |
||||
self.link_to_bandwidth: Dict[str, Dict[float, float]] = link_to_bw |
||||
|
||||
@abstractmethod |
||||
def execute(self): |
||||
raise NotImplementedError |
||||
|
||||
@abstractmethod |
||||
def _eval_fwd_mem_per_region(self, region: Region): |
||||
raise NotImplementedError |
||||
|
||||
@abstractmethod |
||||
def _eval_bwd_mem_per_region(self, region: Region): |
||||
raise NotImplementedError |
||||
|
||||
def _get_bandwidth(self, link: str, comm_volumn: float) -> float: |
||||
""" |
||||
Get the data transfer bandwidth. |
||||
|
||||
Args: |
||||
link (str): the data transfer link. |
||||
comm_volumn (float): the amount of data transferred. |
||||
|
||||
Returns: |
||||
float: the data transfer bandwidth. |
||||
""" |
||||
|
||||
assert len(self.link_to_bandwidth) |
||||
if link not in self.link_to_bandwidth: |
||||
raise TypeError(f"Unknown data transfer link {link}") |
||||
|
||||
# size_list = sorted(list(map(float, self.link_to_bandwidth[link].keys()))) |
||||
size_list = sorted(self.link_to_bandwidth[link].keys()) |
||||
d_idx = bisect.bisect_left(size_list, comm_volumn) |
||||
return self.link_to_bandwidth[link][size_list[d_idx]] |
||||
|
||||
def _get_communication_overhead(self, link: str, comm_volumn: float) -> float: |
||||
return comm_volumn / self._get_bandwidth(link, comm_volumn) |
||||
|
||||
def _get_computing_overhead(self, flop: float) -> float: |
||||
return flop / self.comp_power |
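
`_get_bandwidth` above looks up the bucketed bandwidth table produced by the profiling code earlier in this file: it picks the first profiled transfer size that is at least as large as the requested volume. A standalone sketch of that lookup with an invented table (the numbers are illustrative, not measurements):

```python
import bisect

# Illustrative table: transfer size in bytes -> bandwidth in bytes/s (made-up values).
h2d_table = {1 << 20: 5e9, 4 << 20: 9e9, 16 << 20: 12e9}

def lookup_bandwidth(table, volume):
    sizes = sorted(table.keys())
    idx = bisect.bisect_left(sizes, volume)   # first profiled size >= volume, as in _get_bandwidth
    idx = min(idx, len(sizes) - 1)            # clamp; the class above assumes the volume fits the table
    return table[sizes[idx]]

print(lookup_bandwidth(h2d_table, 3 << 20) / 1e9, "GB/s")   # falls into the 4 MB bucket -> 9.0 GB/s
```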
||||
|
||||
|
||||
class SynTrainingSimulator(TrainingSimulator): |
||||
|
||||
def __init__(self, |
||||
region_list: List[Region], |
||||
comp_power: float, |
||||
link_to_bw: Dict[str, Dict[float, float]]) -> None: |
||||
super().__init__(region_list, comp_power, link_to_bw) |
||||
|
||||
def execute(self): |
||||
""" |
||||
Simulate synchronous training process. |
||||
""" |
||||
|
||||
for reg in self.region_list: |
||||
self._eval_fwd_mem_per_region(reg) |
||||
|
||||
for reg in self.region_list.__reversed__(): |
||||
self._eval_bwd_mem_per_region(reg) |
||||
|
||||
def _eval_fwd_mem_per_region(self, region: Region): |
||||
""" |
||||
Evaluate the runtime and peak memory when the forward execution reaches the current region. |
||||
""" |
||||
|
||||
# upload parameters of the current region |
||||
if requires_upload_p_in_fwd(self.region_list[region.shared_rid]): |
||||
self.runtime_mem += region.param_size |
||||
|
||||
for node in region.nodes: |
||||
self.runtime_mem += calculate_fwd_tmp(node) + \ |
||||
calculate_fwd_out(node) |
||||
self.fwd_node_mem[node] = self.runtime_mem |
||||
self.peak_mem = max(self.runtime_mem, self.peak_mem) |
||||
self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem |
||||
|
||||
if region.need_offload: |
||||
self.runtime_mem -= region.param_size |
||||
|
||||
def _eval_bwd_mem_per_region(self, region: Region): |
||||
""" |
||||
Evaluate the runtime and peak memory when the backward execution reaches the current region. |
||||
""" |
||||
|
||||
# upload parameters of the current region |
||||
if region.need_offload: |
||||
self.runtime_mem += region.param_size |
||||
|
||||
# add the gradient of the parameter |
||||
if region.r_id < region.shared_rid: |
||||
# gradient accumulation is required for shared parameters |
||||
self.runtime_mem += 2.0 * region.param_size |
||||
else: |
||||
self.runtime_mem += region.param_size |
||||
|
||||
for node in region.nodes.__reversed__(): |
||||
|
||||
self.runtime_mem -= calculate_fwd_out(node) |
||||
self.runtime_mem += node.meta['bwd_mem_tmp'] + \ |
||||
node.meta['bwd_mem_out'] |
||||
self.peak_mem = max(self.runtime_mem, self.peak_mem) |
||||
|
||||
# The memory savings of a node may be negative due to parameter prefetch. |
||||
self.total_mem_saving += node.node_info.runtime_bwd_mem - self.runtime_mem |
||||
self.bwd_node_mem[node] = self.runtime_mem |
||||
|
||||
self.runtime_mem -= (node.meta['bwd_mem_tmp'] + |
||||
calculate_fwd_tmp(node)) |
||||
|
||||
# free bwd_mem_out |
||||
self.bwd_node_deps[node] = len(node.all_input_nodes) |
||||
for user_node in node.users: |
||||
if user_node in self.bwd_node_deps: |
||||
self.bwd_node_deps[user_node] -= 1 |
||||
if self.bwd_node_deps[user_node] <= 0: |
||||
self.runtime_mem -= user_node.meta['bwd_mem_out'] |
||||
|
||||
if self.runtime_mem < 0: |
||||
raise ValueError(f"region id: {region.r_id}, node name: {node.name}, " |
||||
f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---" |
||||
f"runtime memory computed less than 0, which is miscalculated!") |
||||
|
||||
# release parameter and offload gradient in region |
||||
if region.r_id == region.shared_rid: |
||||
self.runtime_mem -= 2.0 * region.param_size |
||||
elif region.r_id < region.shared_rid: |
||||
self.runtime_mem -= 3.0 * region.param_size |
||||
elif self.region_list[region.shared_rid].need_offload: |
||||
self.runtime_mem -= region.param_size |
||||
|
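
The release step at the end of `_eval_bwd_mem_per_region` frees different multiples of `param_size` depending on how the region relates to its shared counterpart. The sketch below is just an annotated reading of that branch (my interpretation of the comments above, not extra library logic):

```python
# Annotated reading of the final branch in SynTrainingSimulator._eval_bwd_mem_per_region.
def bwd_release_bytes(r_id: int, shared_rid: int, shared_need_offload: bool, p: float) -> float:
    if r_id == shared_rid:          # unshared region: free its parameter and its gradient
        return 2.0 * p
    if r_id < shared_rid:           # shared parameters with gradient accumulation: parameter + 2x gradient
        return 3.0 * p
    if shared_need_offload:         # the shared counterpart was offloaded, so this copy is freed too
        return p
    return 0.0                      # otherwise the shared parameter stays resident

print(bwd_release_bytes(3, 3, False, 1.0))   # 2.0
print(bwd_release_bytes(1, 5, False, 1.0))   # 3.0
```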
||||
|
||||
class AsynTrainingSimulator(TrainingSimulator): |
||||
|
||||
def __init__(self, |
||||
region_list: List[Region], |
||||
comp_power: float, |
||||
link_to_bw: Dict[str, Dict[float, float]]) -> None: |
||||
super().__init__(region_list, comp_power, link_to_bw) |
||||
|
||||
self.iter_end_time: int = 0 |
||||
# the last computation execution period |
||||
self.last_comp: ExecutionPeriod = ExecutionPeriod( |
||||
start_time=0, end_time=0) |
||||
# the last parameter prefetch execution period |
||||
self.last_h2d: ExecutionPeriod = ExecutionPeriod( |
||||
start_time=0, end_time=0) |
||||
# the last gradient offload execution period |
||||
self.last_d2h: ExecutionPeriod = ExecutionPeriod( |
||||
start_time=0, end_time=0) |
||||
# the forward computation execution period of the region |
||||
self.fwd_reg_to_comp: OrderedDict[int, ExecutionPeriod] = OrderedDict() |
||||
# the forward parameter prefetch execution period of the region |
||||
self.fwd_reg_to_pref: OrderedDict[int, ExecutionPeriod] = OrderedDict() |
||||
# the backward computation execution period of the region |
||||
self.bwd_reg_to_comp: OrderedDict[int, ExecutionPeriod] = OrderedDict() |
||||
# the backward parameter prefetch execution period of the region |
||||
self.bwd_reg_to_pref: OrderedDict[int, ExecutionPeriod] = OrderedDict() |
||||
# the gradient offload execution period of the region |
||||
# which is divided into those that are waiting and those that have been released |
||||
self.bwd_reg_to_offl_waiting: OrderedDict[int, |
||||
ExecutionPeriod] = OrderedDict() |
||||
self.bwd_reg_to_offl_freed: OrderedDict[int, |
||||
ExecutionPeriod] = OrderedDict() |
||||
# the region buffer, which records regions that are offloaded but not released |
||||
self.reg_buffer_to_free: List[int] = [] |
||||
|
||||
# node dependencies in backward pass |
||||
self.bwd_node_deps: Dict[Node, int] = {} |
||||
|
||||
# the region execution flow, |
||||
# where fwd_reg_flow[i,j] denotes whether the parameters of the j-th region are on the GPU |
||||
# when the execution reaches the i-th region. |
||||
self.fwd_reg_flow = torch.zeros( |
||||
(self.region_num, self.region_num)).bool() |
||||
self.bwd_reg_flow = torch.zeros( |
||||
(self.region_num, self.region_num)).bool() |
||||
|
||||
def execute(self): |
||||
""" |
||||
Simulate asynchronous training process. |
||||
In the forward pass, parameter prefetching runs one region ahead of computation. |
||||
In the backward pass, parameter prefetching is executed at the location chosen by the solver, |
||||
and gradient offloading is issued as soon as the gradients are ready. |
||||
""" |
||||
|
||||
for reg in self.region_list: |
||||
if reg.param_size and reg.r_id < self.region_num - 1: |
||||
for nr in self.region_list[reg.r_id + 1:]: |
||||
if nr.param_size and requires_upload_p_in_fwd(self.region_list[nr.shared_rid]): |
||||
reg.fwd_prefetch_region = nr |
||||
break |
||||
self._eval_fwd_cost_per_region(reg) |
||||
self._eval_fwd_mem_per_region(reg) |
||||
|
||||
for reg in self.region_list.__reversed__(): |
||||
self._eval_bwd_cost_per_region(reg) |
||||
self._eval_bwd_mem_per_region(reg) |
||||
|
||||
# release remaining grads |
||||
for reg_id, offl_exec in self.bwd_reg_to_offl_waiting.items(): |
||||
self.bwd_reg_to_offl_freed[reg_id] = offl_exec |
||||
self.runtime_mem -= self.region_list[reg_id].param_size |
||||
self.bwd_reg_to_offl_waiting.clear() |
||||
|
||||
self.iter_end_time = max( |
||||
self.last_comp.end_time, self.last_d2h.end_time) |
||||
|
||||
def _insert_h2d_exec(self, region: Region, is_fwd: bool = True): |
||||
""" |
||||
Insert parameter prefetch execution period of the current region to the end of the h2d stream |
||||
""" |
||||
|
||||
pref_start_time = max(self.last_h2d.end_time, self.last_comp.end_time) |
||||
pref_end_time = pref_start_time + \ |
||||
2.0 * self._get_communication_overhead('h2d', region.param_size) |
||||
pref_ep = ExecutionPeriod( |
||||
start_time=pref_start_time, end_time=pref_end_time) |
||||
if is_fwd: |
||||
self.fwd_reg_to_pref[region.r_id] = pref_ep |
||||
else: |
||||
self.bwd_reg_to_pref[region.r_id] = pref_ep |
||||
self.last_h2d = pref_ep |
||||
|
||||
def _insert_comp_exec(self, region: Region, is_fwd: bool = True): |
||||
""" |
||||
Insert computation execution period of the current region to the end of the computing stream |
||||
""" |
||||
|
||||
if is_fwd: |
||||
reg_to_comp = self.fwd_reg_to_comp |
||||
reg_to_pref = self.fwd_reg_to_pref |
||||
flop_key = 'fwd_flop' |
||||
else: |
||||
reg_to_comp = self.bwd_reg_to_comp |
||||
reg_to_pref = self.bwd_reg_to_pref |
||||
flop_key = 'bwd_flop' |
||||
comp_start_time = max(self.last_comp.end_time, reg_to_pref.get( |
||||
region.r_id, ExecutionPeriod(0, 0)).end_time) |
||||
comp_end_time = comp_start_time + \ |
||||
sum([self._get_computing_overhead(node.meta.get(flop_key, 0)) |
||||
for node in region.nodes]) |
||||
comp_ep = ExecutionPeriod( |
||||
start_time=comp_start_time, end_time=comp_end_time) |
||||
reg_to_comp[region.r_id] = comp_ep |
||||
self.last_comp = comp_ep |
||||
|
||||
def _insert_d2h_exec(self, region: Region): |
||||
""" |
||||
Insert gradient offload execution period of the current region to the end of the d2h stream |
||||
""" |
||||
|
||||
offl_start_time = max(self.last_d2h.end_time, self.last_comp.end_time) |
||||
offl_end_time = offl_start_time + \ |
||||
self._get_communication_overhead('d2h', region.param_size) |
||||
offl_ep = ExecutionPeriod( |
||||
start_time=offl_start_time, end_time=offl_end_time) |
||||
self.bwd_reg_to_offl_waiting[region.r_id] = offl_ep |
||||
self.last_d2h = offl_ep |
||||
|
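
The three `_insert_*_exec` helpers above model one compute stream plus separate h2d and d2h copy streams: each new period starts only after the streams it depends on have drained, which is what lets prefetches hide behind computation. A toy trace with invented durations (not library code):

```python
from dataclasses import dataclass

@dataclass
class Period:                 # stand-in for ExecutionPeriod
    start: float = 0.0
    end: float = 0.0

last_comp, last_h2d = Period(), Period()

def h2d(duration):
    global last_h2d
    start = max(last_h2d.end, last_comp.end)      # as in _insert_h2d_exec
    last_h2d = Period(start, start + duration)
    return last_h2d

def comp(duration, own_pref):
    global last_comp
    start = max(last_comp.end, own_pref.end)      # as in _insert_comp_exec
    last_comp = Period(start, start + duration)
    return last_comp

p0 = h2d(1.0)        # upload region 0's parameters:        (0.0, 1.0)
p1 = h2d(1.0)        # prefetch region 1 on the copy stream: (1.0, 2.0)
c0 = comp(5.0, p0)   # region 0 computes after its upload:   (1.0, 6.0)
c1 = comp(5.0, p1)   # region 1's prefetch is fully hidden:  (6.0, 11.0)
print(c1.start)      # 6.0
```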
||||
def _eval_fwd_cost_per_region(self, region: Region): |
||||
""" |
||||
Evaluate computation and communication execution period of the region in forward pass. |
||||
""" |
||||
|
||||
# upload parameters of the first region |
||||
if region.r_id == 0: |
||||
self._insert_h2d_exec(region) |
||||
|
||||
# prefetch parameters of the next region |
||||
fwd_prefetch_region = region.fwd_prefetch_region |
||||
if fwd_prefetch_region and requires_upload_p_in_fwd(self.region_list[fwd_prefetch_region.shared_rid]): |
||||
self._insert_h2d_exec(fwd_prefetch_region) |
||||
|
||||
# execute computation |
||||
self._insert_comp_exec(region) |
||||
|
||||
def _eval_fwd_mem_per_region(self, region: Region): |
||||
""" |
||||
Evaluate the runtime and peak memory when the forward execution reaches the current region. |
||||
""" |
||||
|
||||
# upload parameters of the current region |
||||
if region.r_id <= 0: |
||||
self.runtime_mem += region.param_size |
||||
self.fwd_reg_flow[region.r_id, region.r_id] = True |
||||
else: |
||||
self.fwd_reg_flow[region.r_id] = self.fwd_reg_flow[region.r_id - 1] |
||||
self.fwd_reg_flow[region.r_id, |
||||
self.reg_buffer_to_free] = False |
||||
self.reg_buffer_to_free.clear() |
||||
|
||||
# prefetch parameters of the next region |
||||
fwd_prefetch_region = region.fwd_prefetch_region |
||||
if fwd_prefetch_region and requires_upload_p_in_fwd(self.region_list[fwd_prefetch_region.shared_rid]): |
||||
self.runtime_mem += fwd_prefetch_region.param_size |
||||
self.fwd_reg_flow[region.r_id, |
||||
fwd_prefetch_region.r_id] = True |
||||
|
||||
for node in region.nodes: |
||||
self.runtime_mem += calculate_fwd_tmp(node) + \ |
||||
calculate_fwd_out(node) |
||||
self.peak_mem = max(self.runtime_mem, self.peak_mem) |
||||
|
||||
self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem |
||||
self.fwd_node_mem[node] = self.runtime_mem |
||||
|
||||
if region.need_offload: |
||||
self.runtime_mem -= region.param_size |
||||
|
||||
assert len( |
||||
self.reg_buffer_to_free) <= 1, f'{len(self.reg_buffer_to_free)}' |
||||
self.reg_buffer_to_free.append(region.r_id) |
||||
|
||||
def _eval_bwd_cost_per_region(self, region: Region): |
||||
""" |
||||
Evaluate computation and communication execution period of the region in backward pass. |
||||
""" |
||||
|
||||
# upload parameters of the current region |
||||
if region.is_syn: |
||||
assert region.need_offload |
||||
self._insert_h2d_exec(region, is_fwd=False) |
||||
|
||||
# prefetch parameters of the chosen region, which overlaps with computation |
||||
if region.bwd_prefetch_region is not None: |
||||
self._insert_h2d_exec(region.bwd_prefetch_region, is_fwd=False) |
||||
|
||||
# execute computation |
||||
self._insert_comp_exec(region, is_fwd=False) |
||||
|
||||
# offload gradient |
||||
if requires_offload_g_in_bwd(region): |
||||
self._insert_d2h_exec(region) |
||||
|
||||
assert len(self.reg_buffer_to_free) == 0 |
||||
for reg_id, offl_exec in self.bwd_reg_to_offl_waiting.items(): |
||||
if offl_exec.end_time >= self.last_comp.start_time: |
||||
break |
||||
self.reg_buffer_to_free.append(reg_id) |
||||
self.bwd_reg_to_offl_freed[reg_id] = offl_exec |
||||
|
||||
for reg_id in self.reg_buffer_to_free: |
||||
self.bwd_reg_to_offl_waiting.pop(reg_id) |
||||
|
||||
def _eval_bwd_mem_per_region(self, region: Region): |
||||
""" |
||||
Evaluate the runtime and peak memory when the backward execution reaches the current region. |
||||
""" |
||||
|
||||
if region.r_id + 1 < self.region_num: |
||||
self.bwd_reg_flow[region.r_id] = self.bwd_reg_flow[region.r_id + 1] |
||||
else: |
||||
self.bwd_reg_flow[region.r_id] = self.fwd_reg_flow[-1] |
||||
self.bwd_reg_flow[region.r_id, |
||||
self.reg_buffer_to_free] = False |
||||
|
||||
# free gradients in the buffer |
||||
while len(self.reg_buffer_to_free): |
||||
reg_id = self.reg_buffer_to_free.pop(0) |
||||
self.runtime_mem -= self.region_list[reg_id].param_size |
||||
|
||||
# upload parameters of the current region |
||||
if region.is_syn: |
||||
self.runtime_mem += region.param_size |
||||
self.bwd_reg_flow[region.r_id, region.r_id] = True |
||||
|
||||
# prefetch parameters of the chosen region |
||||
bwd_prefetch_region = region.bwd_prefetch_region |
||||
if bwd_prefetch_region: |
||||
self.runtime_mem += bwd_prefetch_region.param_size |
||||
self.bwd_reg_flow[region.r_id, |
||||
bwd_prefetch_region.r_id] = True |
||||
|
||||
# add the gradient of the parameter |
||||
if region.r_id < region.shared_rid: |
||||
# gradient accumulation is required for shared parameters |
||||
self.runtime_mem += 2.0 * region.param_size |
||||
else: |
||||
self.runtime_mem += region.param_size |
||||
|
||||
for node in region.nodes.__reversed__(): |
||||
|
||||
self.runtime_mem -= calculate_fwd_out(node) |
||||
self.runtime_mem += node.meta['bwd_mem_tmp'] + \ |
||||
node.meta['bwd_mem_out'] |
||||
self.peak_mem = max(self.runtime_mem, self.peak_mem) |
||||
|
||||
# The memory savings of a node may be negative due to parameter prefetch. |
||||
self.total_mem_saving += node.node_info.runtime_bwd_mem - self.runtime_mem |
||||
|
||||
self.bwd_node_mem[node] = self.runtime_mem |
||||
|
||||
self.runtime_mem -= (node.meta['bwd_mem_tmp'] + |
||||
calculate_fwd_tmp(node)) |
||||
|
||||
# free bwd_mem_out |
||||
self.bwd_node_deps[node] = len(node.all_input_nodes) |
||||
for user_node in node.users: |
||||
if user_node in self.bwd_node_deps: |
||||
self.bwd_node_deps[user_node] -= 1 |
||||
if self.bwd_node_deps[user_node] <= 0: |
||||
self.runtime_mem -= user_node.meta['bwd_mem_out'] |
||||
|
||||
if self.runtime_mem < 0: |
||||
raise ValueError(f"region id: {region.r_id}, node name: {node.name}, " |
||||
f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---" |
||||
f"runtime memory computed less than 0, which is miscalculated!") |
||||
|
||||
# release parameters of the region |
||||
if requires_release_p_in_bwd(self.region_list[region.shared_rid]): |
||||
self.runtime_mem -= region.param_size |
@ -0,0 +1,90 @@
|
||||
from dataclasses import dataclass |
||||
from typing import List |
||||
import torch |
||||
from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp |
||||
|
||||
from .region import Region |
||||
|
||||
|
||||
@dataclass |
||||
class NodeInfo: |
||||
node_id: int = 0 |
||||
runtime_fwd_mem: float = 0 |
||||
runtime_bwd_mem: float = 0 |
||||
|
||||
class NvDevicePower: |
||||
""" |
||||
NVIDIA GPU computing performance (TFLOPs). |
||||
""" |
||||
|
||||
RTX3080_FP16 = 70 |
||||
RTX3080_FP32 = 34.1 |
||||
|
||||
RTX3090_FP16 = 71 |
||||
RTX3090_FP32 = 35.7 |
||||
|
||||
V100_FP16 = 31.4 |
||||
V100_FP32 = 15.7 |
||||
|
||||
A100_FP16 = 78 |
||||
A100_FP32 = 19.5 |
||||
|
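
The constants above are peak throughputs in TFLOPS, whereas `_get_computing_overhead` divides raw FLOP counts by `comp_power`. The wiring between the two is not part of this diff, so the conversion in the sketch below is an assumption for illustration only:

```python
# Assumption: the simulator's comp_power is the device throughput in FLOP/s,
# i.e. the TFLOPS constants above scaled by 1e12 somewhere outside this diff.
a100_fp16_flops_per_s = 78 * 1e12              # NvDevicePower.A100_FP16

region_fwd_flop = 2.5e12                       # made-up FLOP count for one region
estimated_seconds = region_fwd_flop / a100_fp16_flops_per_s   # formula of _get_computing_overhead
print(f"{estimated_seconds * 1e3:.2f} ms")     # ~32.05 ms
```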
||||
|
||||
class GlobalRuntimeInfo: |
||||
h2d_stream = torch.cuda.Stream() |
||||
d2h_stream = torch.cuda.Stream() |
||||
fwd_prefetch_event_map = {} |
||||
bwd_prefetch_event_map = {} |
||||
region_list = [] |
||||
|
||||
|
||||
def compute_act_peak_mem(region_list: List[Region]) -> float: |
||||
act_peak_mem = 0 |
||||
runtime_mem = 0 |
||||
# forward |
||||
for region in region_list: |
||||
for node in region.nodes: |
||||
runtime_mem = runtime_mem + \ |
||||
calculate_fwd_tmp(node) + calculate_fwd_out(node) |
||||
act_peak_mem = max(runtime_mem, act_peak_mem) |
||||
# backward |
||||
bwd_deps = {} |
||||
for region in region_list.__reversed__(): |
||||
for node in region.nodes.__reversed__(): |
||||
runtime_mem -= calculate_fwd_out(node) |
||||
runtime_mem = runtime_mem + \ |
||||
node.meta['bwd_mem_tmp'] + node.meta['bwd_mem_out'] |
||||
|
||||
act_peak_mem = max(runtime_mem, act_peak_mem) |
||||
|
||||
runtime_mem = runtime_mem - \ |
||||
node.meta['bwd_mem_tmp'] - calculate_fwd_tmp(node) |
||||
|
||||
# free bwd_mem_out |
||||
bwd_deps[node] = len(node.all_input_nodes) |
||||
for user_node in node.users: |
||||
if user_node in bwd_deps: |
||||
bwd_deps[user_node] -= 1 |
||||
if bwd_deps[user_node] <= 0: |
||||
runtime_mem -= user_node.meta['bwd_mem_out'] |
||||
|
||||
return act_peak_mem |
||||
|
||||
def compute_max_param_mem(region_list: List[Region]) -> float: |
||||
return max(region.param_size for region in region_list) |
||||
|
||||
def compute_total_param_mem(region_list: List[Region]) -> float: |
||||
return sum(region.param_size for region in region_list if region.r_id <= region.shared_rid) |
||||
|
||||
def requires_upload_p_in_fwd(shared_reg: Region): |
||||
return (shared_reg.r_id >= shared_reg.shared_rid) or ( |
||||
shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload) |
||||
|
||||
def requires_release_p_in_bwd(shared_reg: Region): |
||||
return (shared_reg.r_id >= shared_reg.shared_rid) or ( |
||||
shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload) |
||||
|
||||
def requires_offload_g_in_bwd(region: Region): |
||||
return region.param_size and (region.r_id <= region.shared_rid) |
||||
|
||||
|
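
The predicates above decide when a (possibly shared) region's parameters are uploaded and released, and when its gradients are offloaded. A quick check with a stand-in object that only carries the fields the predicates read (the real `Region` has much more state):

```python
from types import SimpleNamespace

def fake_region(r_id, shared_rid, need_offload=False, param_size=1):
    # stand-in exposing only the attributes the predicates above access
    return SimpleNamespace(r_id=r_id, shared_rid=shared_rid,
                           need_offload=need_offload, param_size=param_size)

plain       = fake_region(r_id=3, shared_rid=3)                     # shared_rid points at itself
shared      = fake_region(r_id=1, shared_rid=5)                     # parameters shared with region 5
shared_offl = fake_region(r_id=1, shared_rid=5, need_offload=True)  # same, but marked for offloading

print(requires_upload_p_in_fwd(plain))        # True  (r_id >= shared_rid)
print(requires_upload_p_in_fwd(shared))       # False (r_id < shared_rid and not offloaded)
print(requires_upload_p_in_fwd(shared_offl))  # True  (offloaded, so it must be uploaded again)
print(requires_offload_g_in_bwd(shared))      # True  (param_size and r_id <= shared_rid)
```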
@ -0,0 +1,37 @@
|
||||
# Auto-Offload Demo with GPT2 |
||||
|
||||
## Requirements |
||||
|
||||
Before you can launch training, you need to install the following requirements. |
||||
|
||||
### Install PyTorch |
||||
|
||||
```bash |
||||
#conda |
||||
conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch |
||||
#pip |
||||
pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113 |
||||
``` |
||||
|
||||
### Install [Colossal-AI v0.2.0](https://colossalai.org/download/) From Official Website |
||||
|
||||
```bash |
||||
pip install colossalai==0.2.0+torch1.12cu11.3 -f https://release.colossalai.org |
||||
``` |
||||
|
||||
### Install transformers |
||||
|
||||
```bash |
||||
pip install transformers |
||||
``` |
||||
|
||||
## Dataset |
||||
|
||||
For simplicity, the input data is randomly generated here. |
||||
|
||||
## Training |
||||
|
||||
```bash |
||||
# Run auto offload on GPT2 with the default settings and a dummy dataset. |
||||
bash run.sh |
||||
``` |
@ -0,0 +1,65 @@
|
||||
import torch |
||||
import torch.nn as nn |
||||
from transformers import GPT2Config, GPT2LMHeadModel |
||||
|
||||
class GPTLMModel(nn.Module): |
||||
|
||||
def __init__(self, |
||||
hidden_size=768, |
||||
num_layers=12, |
||||
num_attention_heads=12, |
||||
max_seq_len=1024, |
||||
vocab_size=50257): |
||||
super().__init__() |
||||
self.model = GPT2LMHeadModel( |
||||
GPT2Config(n_embd=hidden_size, |
||||
n_layer=num_layers, |
||||
n_head=num_attention_heads, |
||||
n_positions=max_seq_len, |
||||
n_ctx=max_seq_len, |
||||
vocab_size=vocab_size)) |
||||
|
||||
def forward(self, input_ids, attention_mask): |
||||
# Only return lm_logits |
||||
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0] |
||||
|
||||
|
||||
class GPTLMLoss(nn.Module): |
||||
|
||||
def __init__(self): |
||||
super().__init__() |
||||
self.loss_fn = nn.CrossEntropyLoss() |
||||
|
||||
def forward(self, logits, labels): |
||||
shift_logits = logits[..., :-1, :].contiguous() |
||||
shift_labels = labels[..., 1:].contiguous() |
||||
# Flatten the tokens |
||||
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) |
||||
|
||||
def get_gpt2_components(model_type: str, batch_size: int): |
||||
vocab_size = 1024 |
||||
seq_len = 8 |
||||
|
||||
def gpt2_model_builder(): |
||||
if model_type == "gpt2_medium": |
||||
return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16) |
||||
elif model_type == "gpt2_xl": |
||||
return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32) |
||||
elif model_type == "gpt2_10b": |
||||
return GPTLMModel(hidden_size=4096, num_layers=50, num_attention_heads=16) |
||||
elif model_type == "gpt2_14b": |
||||
return GPTLMModel(hidden_size=4096, num_layers=70, num_attention_heads=16) |
||||
elif model_type == "gpt2_20b": |
||||
return GPTLMModel(hidden_size=8192, num_layers=25, num_attention_heads=16) |
||||
elif model_type == "gpt2_24b": |
||||
return GPTLMModel(hidden_size=8192, num_layers=30, num_attention_heads=16) |
||||
else: |
||||
raise TypeError(f"model_builder {model_type}") |
||||
|
||||
def gpt2_data_gen(device="cuda"): |
||||
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device) |
||||
attention_mask = torch.ones_like(input_ids, device=device) |
||||
kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) |
||||
return kwargs |
||||
|
||||
return gpt2_model_builder, gpt2_data_gen |
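
`get_gpt2_components` returns a builder and a data generator so the training script can construct the model and its dummy batch independently; a minimal sketch of how they fit together (it mirrors `train_gpt_offload.py` below):

```python
# Minimal sketch mirroring the usage in train_gpt_offload.py below.
model_builder, data_gen = get_gpt2_components(model_type="gpt2_medium", batch_size=4)

model = model_builder()           # GPT2LMHeadModel wrapped by GPTLMModel
batch = data_gen(device="cpu")    # dict with input_ids and attention_mask, seq_len=8
logits = model(**batch)           # forward pass returns only lm_logits
print(logits.shape)               # torch.Size([4, 8, 50257]) with the default GPT2 vocab
```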
@ -0,0 +1,2 @@
|
||||
colossalai >= 0.1.12 |
||||
torch >= 1.8.1 |
@ -0,0 +1,8 @@
|
||||
export BATCH_SIZE=${BATCH_SIZE:-64} |
||||
export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"} |
||||
export MEMORY_BUDGET=${MEMORY_BUDGET:-16} |
||||
export SOLVER_TYPE=${SOLVER_TYPE:-"asyn"} |
||||
|
||||
mkdir -p offload_logs |
||||
|
||||
python train_gpt_offload.py --model_type=${MODEL_TYPE} --memory_budget=${MEMORY_BUDGET} --solver_type=${SOLVER_TYPE} --batch_size=${BATCH_SIZE} 2>&1 | tee ./offload_logs/${MODEL_TYPE}_bs_${BATCH_SIZE}_st_${SOLVER_TYPE}.log |
@ -0,0 +1,94 @@
|
||||
import time |
||||
import pytest |
||||
import argparse |
||||
from functools import partial |
||||
|
||||
import torch |
||||
from torch.utils._pytree import tree_map |
||||
import torch.multiprocessing as mp |
||||
|
||||
import colossalai |
||||
from colossalai.nn.optimizer import HybridAdam |
||||
from colossalai.fx.profiler import parameter_size |
||||
from colossalai.utils import free_port, get_current_device |
||||
from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer |
||||
from colossalai.auto_parallel.offload.mem_optimize import memory_optimize |
||||
from colossalai.auto_parallel.offload.solver import NOT_NVML |
||||
from model_zoo import get_gpt2_components, GPTLMLoss |
||||
|
||||
def parse_args(): |
||||
parser = argparse.ArgumentParser() |
||||
parser.add_argument('--model_type', type=str, default="gpt2_medium") |
||||
parser.add_argument('--batch_size', type=int, default=64) |
||||
parser.add_argument('--solver_type', type=str, default='asyn') |
||||
parser.add_argument('--memory_budget', type=float, default=16) |
||||
return parser.parse_args() |
||||
|
||||
@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed') |
||||
def train_gpt(args): |
||||
memory_budget = args.memory_budget * 1024 * 1024 * 1024 |
||||
solver_type = args.solver_type |
||||
model_type = args.model_type |
||||
batch_size = args.batch_size |
||||
|
||||
# build model |
||||
model_builder, data_gen = get_gpt2_components(model_type=model_type, batch_size=batch_size) |
||||
label = torch.randint(low=0, high=128, size=(batch_size, 8,), device=get_current_device())  # match the generated batch shape (batch_size, seq_len=8) |
||||
criterion = GPTLMLoss() |
||||
|
||||
start_time = time.time() |
||||
model = model_builder() |
||||
model.train() |
||||
param_size = parameter_size(model) / 1024 ** 2 / 2 |
||||
init_time = time.time() - start_time |
||||
print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s") |
||||
|
||||
data_args = data_gen(device="cpu") |
||||
wrap_fn = lambda x: x.to(dtype=torch.half) if isinstance(x, torch.Tensor) and torch.is_floating_point(x) else x |
||||
data_args = tree_map(wrap_fn, data_args) |
||||
start_time = time.time() |
||||
model = memory_optimize(model, data_args, memory_budget, solver_type) |
||||
solver_time = time.time() - start_time |
||||
print(f"solver_time={solver_time:.3f} s") |
||||
|
||||
hybrid_optimizer = HybridAdam(model.model.parameters(), lr=1e-3) |
||||
optim = AMPOptimizer(hybrid_optimizer, model) |
||||
|
||||
torch.cuda.empty_cache() |
||||
torch.cuda.synchronize() |
||||
torch.cuda.reset_peak_memory_stats() |
||||
|
||||
time_list = [] |
||||
data_args = data_gen(device="cuda") |
||||
data_args = tree_map(wrap_fn, data_args) |
||||
for step in range(10): |
||||
optim.zero_grad() |
||||
torch.cuda.synchronize() |
||||
start_time = time.time() |
||||
loss = criterion(model(**data_args), label) |
||||
optim.backward(loss) |
||||
torch.cuda.synchronize() |
||||
time_list.append(time.time() - start_time) |
||||
optim.step() |
||||
|
||||
torch.cuda.synchronize() |
||||
|
||||
exec_time = sum(sorted(time_list)[:5]) / 5 |
||||
runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2 |
||||
runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2 |
||||
print(f'solver_type: {solver_type} | model_type: {model_type}') |
||||
print( |
||||
f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' |
||||
f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|' |
||||
) |
||||
print(time_list) |
||||
|
||||
def run(rank, world_size, port, args): |
||||
config = {} |
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') |
||||
train_gpt(args) |
||||
|
||||
if __name__ == '__main__': |
||||
args = parse_args() |
||||
run_func = partial(run, world_size=1, port=free_port(), args=args) |
||||
mp.spawn(run_func, nprocs=1) |
@ -0,0 +1,86 @@
|
||||
import torch |
||||
import torch.nn as nn |
||||
from transformers import GPT2Config, GPT2LMHeadModel |
||||
from transformers import BertConfig, BertLMHeadModel |
||||
from tests.components_to_test.registry import non_distributed_component_funcs |
||||
|
||||
class GPTLMModel(nn.Module): |
||||
|
||||
def __init__(self, |
||||
hidden_size=768, |
||||
num_layers=12, |
||||
num_attention_heads=12, |
||||
max_seq_len=1024, |
||||
vocab_size=50257): |
||||
super().__init__() |
||||
self.model = GPT2LMHeadModel( |
||||
GPT2Config(n_embd=hidden_size, |
||||
n_layer=num_layers, |
||||
n_head=num_attention_heads, |
||||
n_positions=max_seq_len, |
||||
n_ctx=max_seq_len, |
||||
vocab_size=vocab_size)) |
||||
|
||||
def forward(self, input_ids, attention_mask): |
||||
# Only return lm_logits |
||||
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0] |
||||
|
||||
|
||||
class LMLoss(nn.Module): |
||||
|
||||
def __init__(self): |
||||
super().__init__() |
||||
self.loss_fn = nn.CrossEntropyLoss() |
||||
|
||||
def forward(self, logits, labels): |
||||
shift_logits = logits[..., :-1, :].contiguous() |
||||
shift_labels = labels[..., 1:].contiguous() |
||||
# Flatten the tokens |
||||
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) |
||||
|
||||
class BertLMModel(nn.Module): |
||||
def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=32, vocab_size=30522): |
||||
super().__init__() |
||||
self.model = BertLMHeadModel(BertConfig(n_embd=hidden_size, num_hidden_layers=num_layers, hidden_size=hidden_size, |
||||
num_attention_heads=num_attention_heads, max_position_embeddings=hidden_size, |
||||
vocab_size=vocab_size)) |
||||
|
||||
def forward(self, input_ids, attention_mask): |
||||
# Only return lm_logits |
||||
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0] |
||||
|
||||
@non_distributed_component_funcs.register(name='bert_') |
||||
def get_bert_components(): |
||||
vocab_size = 1024 |
||||
seq_len = 64 |
||||
batchSize = 64 |
||||
|
||||
def bert_model_builder(): |
||||
model = BertLMModel(hidden_size=8192, num_layers=4, num_attention_heads=32, vocab_size=vocab_size) |
||||
return model |
||||
|
||||
def bert_data_gen(device="meta"): |
||||
input_ids = torch.randint(0, vocab_size, (batchSize, seq_len), device=device) |
||||
attention_mask = torch.ones_like(input_ids, device=device) |
||||
kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) |
||||
return kwargs |
||||
|
||||
return bert_model_builder, bert_data_gen |
||||
|
||||
@non_distributed_component_funcs.register(name='gpt2_') |
||||
def get_gpt2_components(): |
||||
vocab_size = 1024 |
||||
seq_len = 8 |
||||
batchSize = 64 |
||||
|
||||
def gpt2_model_builder(): |
||||
model = GPTLMModel(hidden_size=8192, num_layers=2, num_attention_heads=32, vocab_size=vocab_size) |
||||
return model |
||||
|
||||
def gpt2_data_gen(device="meta"): |
||||
input_ids = torch.randint(0, vocab_size, (batchSize, seq_len), device=device) |
||||
attention_mask = torch.ones_like(input_ids, device=device) |
||||
kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) |
||||
return kwargs |
||||
|
||||
return gpt2_model_builder, gpt2_data_gen |
@ -0,0 +1,150 @@
|
||||
import time |
||||
import pytest |
||||
from functools import partial |
||||
|
||||
import torch |
||||
from torch.utils._pytree import tree_map |
||||
import torch.multiprocessing as mp |
||||
|
||||
import colossalai |
||||
from colossalai.nn.optimizer import HybridAdam |
||||
from colossalai.fx.profiler import parameter_size |
||||
from colossalai.utils.model.colo_init_context import ColoInitContext |
||||
from colossalai.utils import free_port, get_current_device |
||||
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper |
||||
from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer |
||||
from colossalai.auto_parallel.offload.mem_optimize import memory_optimize |
||||
from colossalai.auto_parallel.offload.solver import NOT_NVML |
||||
from colossalai.testing import parameterize |
||||
|
||||
from tests.test_tensor.common_utils import set_seed |
||||
from tests.test_auto_parallel.test_offload.model_utils import * |
||||
|
||||
|
||||
@parameterize('model_name', ['gpt2_']) |
||||
@parameterize('memory_budget', [5000]) |
||||
@parameterize('solver_name', ['asyn']) |
||||
def exam_fwd_bwd( |
||||
model_name: str, |
||||
memory_budget: float, |
||||
solver_name: str |
||||
): |
||||
|
||||
# build model |
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name) |
||||
model_builder, data_gen = get_components_func() |
||||
label = torch.randint(low=0, high=128, size=(64, 8,), device=get_current_device()) |
||||
criterion = LMLoss() |
||||
|
||||
set_seed(42) |
||||
start_time = time.time() |
||||
model = model_builder() |
||||
model.train() |
||||
param_size = parameter_size(model) / 1024 ** 2 / 2 |
||||
init_time = time.time() - start_time |
||||
print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s") |
||||
|
||||
data_args = data_gen(device="cpu") |
||||
wrap_fn = lambda x: x.to(dtype=torch.half) if isinstance(x, torch.Tensor) and torch.is_floating_point(x) else x |
||||
data_args = tree_map(wrap_fn, data_args) |
||||
start_time = time.time() |
||||
model = memory_optimize(model, data_args, memory_budget * 1024 * 1024, solver_name) |
||||
solver_time = time.time() - start_time |
||||
print(f"solver_time={solver_time:.3f} s") |
||||
|
||||
hybrid_optimizer = HybridAdam(model.model.parameters(), lr=1e-3) |
||||
optim = AMPOptimizer(hybrid_optimizer, model) |
||||
|
||||
with ColoInitContext(device=torch.device('cpu')): |
||||
gemini_model = model_builder() |
||||
gemini_model.train() |
||||
|
||||
hybrid_optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3) |
||||
gemini_config = dict(strict_ddp_mode=False, |
||||
device=torch.device('cpu'), |
||||
placement_policy='cpu', |
||||
pin_memory=True, |
||||
hidden_dim=8192, |
||||
search_range_mb=128) |
||||
gemini_model = zero_model_wrapper(gemini_model, 3, gemini_config) |
||||
optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True) |
||||
gemini_optim = zero_optim_wrapper(gemini_model, hybrid_optimizer, optim_config=optim_config) |
||||
|
||||
torch.cuda.empty_cache() |
||||
torch.cuda.synchronize() |
||||
torch.cuda.reset_peak_memory_stats() |
||||
|
||||
# test gemini |
||||
time_list = [] |
||||
set_seed(42) |
||||
data_args = data_gen(device="cuda") |
||||
for step in range(10): |
||||
gemini_optim.zero_grad() |
||||
torch.cuda.synchronize() |
||||
start_time = time.time() |
||||
gemini_out = gemini_model(**data_args) |
||||
gemini_loss = criterion(gemini_out, label) |
||||
gemini_optim.backward(gemini_loss) |
||||
torch.cuda.synchronize() |
||||
time_list.append(time.time() - start_time) |
||||
gemini_optim.step() |
||||
|
||||
torch.cuda.synchronize() |
||||
|
||||
exec_time = sum(sorted(time_list)[:5]) / 5 |
||||
runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2 |
||||
runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2 |
||||
print(f'gemini | model_name: {model_name}') |
||||
print( |
||||
f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' |
||||
f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|' |
||||
) |
||||
print(time_list) |
||||
|
||||
del data_args |
||||
del gemini_model |
||||
del gemini_optim |
||||
del gemini_out |
||||
del gemini_loss |
||||
|
||||
# test asyn offload |
||||
torch.cuda.empty_cache() |
||||
torch.cuda.synchronize() |
||||
torch.cuda.reset_peak_memory_stats() |
||||
|
||||
time_list = [] |
||||
set_seed(42) |
||||
data_args = data_gen(device="cuda") |
||||
data_args = tree_map(wrap_fn, data_args) |
||||
for step in range(10): |
||||
optim.zero_grad() |
||||
torch.cuda.synchronize() |
||||
start_time = time.time() |
||||
loss = criterion(model(**data_args), label) |
||||
optim.backward(loss) |
||||
torch.cuda.synchronize() |
||||
time_list.append(time.time() - start_time) |
||||
optim.step() |
||||
|
||||
torch.cuda.synchronize() |
||||
|
||||
exec_time = sum(sorted(time_list)[:5]) / 5 |
||||
runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2 |
||||
runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2 |
||||
print(f'solver_name: {solver_name} | model_name: {model_name}') |
||||
print( |
||||
f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' |
||||
f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|' |
||||
) |
||||
print(time_list) |
||||
|
||||
@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed') |
||||
def test_perf(rank, world_size, port): |
||||
config = {} |
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') |
||||
exam_fwd_bwd() |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
run_func = partial(test_perf, world_size=1, port=free_port()) |
||||
mp.spawn(run_func, nprocs=1) |
@ -0,0 +1,62 @@
|
||||
import pytest |
||||
import torch.fx |
||||
from torch.fx import GraphModule |
||||
from torch.utils._pytree import tree_map |
||||
|
||||
from colossalai.fx import ColoTracer, is_compatible_with_meta |
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp |
||||
from colossalai.auto_parallel.offload.region_manager import RegionManager |
||||
from colossalai.auto_parallel.offload.solver import SolverFactory, NOT_NVML |
||||
from colossalai.testing import parameterize |
||||
from tests.test_auto_parallel.test_offload.model_utils import * |
||||
|
||||
@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed') |
||||
@parameterize('model_name', ['gpt2_', 'bert_']) |
||||
@parameterize('memory_budget', [4000]) |
||||
@parameterize('solver_name', ['syn', 'asyn']) |
||||
def solver_test(model_name: str, |
||||
memory_budget: float, |
||||
solver_name: str): |
||||
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name) |
||||
model_builder, data_gen = get_components_func() |
||||
data_args = data_gen(device="cpu") |
||||
wrap_fn = lambda x: x.to(dtype=torch.half) if isinstance(x, torch.Tensor) and torch.is_floating_point(x) else x |
||||
data_args = tree_map(wrap_fn, data_args) |
||||
model = model_builder() |
||||
model.train() |
||||
model = model.cpu().half() |
||||
|
||||
tracer = ColoTracer() |
||||
assert is_compatible_with_meta() |
||||
wrap_fn = lambda x: x.to("meta") if isinstance(x, torch.Tensor) else x |
||||
meta_args = tree_map(wrap_fn, data_args) |
||||
graph = tracer.trace(model, meta_args=meta_args) |
||||
gm = GraphModule(model, graph, model.__class__.__name__) |
||||
|
||||
interp = MetaInfoProp(gm) |
||||
interp.propagate(*meta_args.values()) |
||||
|
||||
region_manager = RegionManager(graph, solver_name=solver_name) |
||||
region_manager._pre_process() |
||||
region_list = region_manager.region_list |
||||
|
||||
solver_cls = SolverFactory.create(solver_name) |
||||
memory_budget = memory_budget * 1024 * 1024 |
||||
solver = solver_cls(region_list, memory_budget) |
||||
solver._call_solver() |
||||
|
||||
assert solver.best_ts.peak_mem < memory_budget |
||||
|
||||
print("****************** execution plan *******************") |
||||
for region in region_list: |
||||
need_offload = region.need_offload |
||||
to_prefetch = region.fwd_prefetch_region.r_id if region.fwd_prefetch_region is not None else None |
||||
print(f'| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}') |
||||
for region in region_list.__reversed__(): |
||||
need_offload = region.need_offload |
||||
to_prefetch = region.bwd_prefetch_region.r_id if region.bwd_prefetch_region is not None else None |
||||
print(f'| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}') |
||||
|
||||
if __name__ == '__main__': |
||||
solver_test() |