mirror of https://github.com/hpcaitech/ColossalAI
[auto-parallel] add auto-offload feature (#3154)
* add auto-offload feature * polish code * fix syn offload runtime pass bug * add offload example * fix offload testing bug * fix example testing bugpull/3190/head^2
parent
258b43317c
commit
18dbe76cae
|
@ -0,0 +1,177 @@
|
|||
from typing import Dict, Tuple
|
||||
from enum import Enum
|
||||
import torch
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
||||
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .base_offload_module import BaseOffloadModule
|
||||
from .region_manager import RegionManager
|
||||
from .region import Region
|
||||
|
||||
|
||||
class OptimState(Enum):
|
||||
SCALED = 0
|
||||
UNSCALED = 1
|
||||
|
||||
class AMPOptimizer(ColossalaiOptimizer):
|
||||
|
||||
"""
|
||||
A wrapper for Optimizer.
|
||||
Code reference: https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/optimizer/zero_optimizer.py
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): An Optimizer instance.
|
||||
module (BaseOffloadModule): A ``BaseOffloadModule`` instance.
|
||||
initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**16.
|
||||
growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
|
||||
backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
|
||||
growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
|
||||
hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
|
||||
min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
|
||||
max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
|
||||
norm_type (float, optional): norm_type used for `clip_grad_norm`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
optimizer: Optimizer,
|
||||
module: BaseOffloadModule,
|
||||
initial_scale: float = 2**16,
|
||||
growth_factor: float = 2,
|
||||
backoff_factor: float = 0.5,
|
||||
growth_interval: int = 1000,
|
||||
hysteresis: int = 2,
|
||||
min_scale: float = 1,
|
||||
max_scale: float = 2**32,
|
||||
clipping_norm: float = 0.0,
|
||||
norm_type: float = 2.0):
|
||||
|
||||
super().__init__(optimizer)
|
||||
|
||||
self.module = module
|
||||
self.optim_state = OptimState.UNSCALED
|
||||
self.clipping_flag = clipping_norm > 0.0
|
||||
self.max_norm = clipping_norm
|
||||
|
||||
self.region_manager: RegionManager = self.module.region_manager
|
||||
self.param_to_range: Dict[torch.nn.Parameter, Tuple[int, int]] = dict()
|
||||
self.param_to_region: Dict[torch.nn.Parameter, Region] = dict()
|
||||
|
||||
self.fp32_to_fp16_params: Dict[torch.Tensor, torch.nn.Parameter] = dict()
|
||||
|
||||
if self.clipping_flag:
|
||||
assert norm_type == 2.0, "AMPOptimizer only supports L2 norm now"
|
||||
|
||||
self.__init__optimizer()
|
||||
|
||||
# Grad scaler
|
||||
self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
|
||||
min_scale=min_scale,
|
||||
growth_factor=growth_factor,
|
||||
backoff_factor=backoff_factor,
|
||||
growth_interval=growth_interval,
|
||||
hysteresis=hysteresis,
|
||||
max_scale=max_scale)
|
||||
self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
|
||||
self._logger = get_dist_logger()
|
||||
|
||||
def _set_grad_ptr(self):
|
||||
for group in self.param_groups:
|
||||
for fake_param in group['params']:
|
||||
region = self.param_to_region[fake_param]
|
||||
begin, end = self.param_to_range[fake_param]
|
||||
|
||||
fake_param.data = region.cpu_grad[begin:end]
|
||||
fake_param.grad = fake_param.data
|
||||
fake_param.data = region.fp32_data[begin:end]
|
||||
|
||||
def _update_fp16_params(self):
|
||||
none_tensor = torch.empty([0])
|
||||
for group in self.param_groups:
|
||||
for fake_param in group['params']:
|
||||
assert fake_param.grad is None
|
||||
fake_param.data = none_tensor
|
||||
self.param_to_region[fake_param].cpu_grad = None
|
||||
|
||||
def _check_overflow(self):
|
||||
# clear previous overflow record
|
||||
self._found_overflow.fill_(self.module.overflow_counter.item())
|
||||
return self._found_overflow.item() > 0
|
||||
|
||||
def _get_combined_scale(self):
|
||||
loss_scale = 1
|
||||
|
||||
if self.optim_state == OptimState.SCALED:
|
||||
loss_scale = self.loss_scale
|
||||
self.optim_state = OptimState.UNSCALED
|
||||
|
||||
combined_scale = loss_scale
|
||||
|
||||
if combined_scale == 1:
|
||||
return -1
|
||||
else:
|
||||
return combined_scale
|
||||
|
||||
@property
|
||||
def loss_scale(self):
|
||||
return self.grad_scaler.scale.item()
|
||||
|
||||
def zero_grad(self, *args, **kwargs):
|
||||
self.module.overflow_counter = torch.cuda.IntTensor([0])
|
||||
return self.optim.zero_grad(set_to_none=True)
|
||||
|
||||
def step(self, *args, **kwargs):
|
||||
# Copy gradients from model params to main params.
|
||||
self._set_grad_ptr()
|
||||
|
||||
found_inf = self._check_overflow()
|
||||
if found_inf:
|
||||
self.optim_state = OptimState.UNSCALED # no need to unscale grad
|
||||
self.grad_scaler.update(found_inf) # update gradient scaler
|
||||
self._logger.info(f'Found overflow. Skip step')
|
||||
self.zero_grad() # reset all gradients
|
||||
self._update_fp16_params()
|
||||
return
|
||||
|
||||
# get combined scale. combined scale = loss scale * clipping norm
|
||||
# so that gradient = gradient / combined scale
|
||||
combined_scale = self._get_combined_scale()
|
||||
self.grad_scaler.update(found_inf)
|
||||
|
||||
ret = self.optim.step(div_scale=combined_scale, *args, **kwargs)
|
||||
self.zero_grad()
|
||||
self._update_fp16_params()
|
||||
return ret
|
||||
|
||||
def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):
|
||||
raise NotImplementedError
|
||||
|
||||
def backward(self, loss: torch.Tensor):
|
||||
loss = self.loss_scale * loss
|
||||
self.optim_state = OptimState.SCALED
|
||||
self.module.backward(loss)
|
||||
|
||||
def __init__optimizer(self):
|
||||
|
||||
for group in self.optim.param_groups:
|
||||
fake_params_list = list()
|
||||
|
||||
for param in group['params']:
|
||||
region = self.region_manager.get_region(param)
|
||||
fake_param = torch.nn.Parameter(torch.empty([0]))
|
||||
self.param_to_range[fake_param] = region.param_to_range[param]
|
||||
self.param_to_region[fake_param] = region
|
||||
fake_params_list.append(fake_param)
|
||||
|
||||
# Reset existing state dict key to the new main param.
|
||||
if param in self.optim.state:
|
||||
self.optim.state[fake_param] = self.optim.state.pop(param)
|
||||
|
||||
group['params'] = fake_params_list
|
||||
|
||||
# Leverage state_dict() and load_state_dict() to
|
||||
# recast preexisting per-param state tensors
|
||||
self.optim.load_state_dict(self.optim.state_dict())
|
|
@ -0,0 +1,109 @@
|
|||
from typing import Optional, Set
|
||||
from functools import partial
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.nn.parallel.data_parallel import _cast_float
|
||||
from colossalai.gemini.tensor_utils import free_storage
|
||||
|
||||
from .region_manager import RegionManager
|
||||
from .util import GlobalRuntimeInfo
|
||||
|
||||
|
||||
class BaseOffloadModule:
|
||||
"""
|
||||
BaseOffloadModule: A model wrapper for parameter offloading.
|
||||
|
||||
Args:
|
||||
model (nn.Module): model to apply offloading.
|
||||
region_manager (RegionManager): a ``RegionManager`` instance.
|
||||
is_sync (bool): synchronous mode or not.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model: nn.Module,
|
||||
region_manager: RegionManager,
|
||||
is_sync=True):
|
||||
|
||||
self.model = model
|
||||
self.region_manager = region_manager
|
||||
self.grad_hook_list = []
|
||||
self.overflow_counter = torch.cuda.IntTensor([0])
|
||||
|
||||
self.grad_offload_stream = torch.cuda.current_stream() if is_sync else GlobalRuntimeInfo.d2h_stream
|
||||
|
||||
self._cast_buffers()
|
||||
|
||||
def register_grad_hook(self):
|
||||
for p in self.model.parameters():
|
||||
if p.requires_grad:
|
||||
self.grad_hook_list.append(p.register_hook(partial(self.grad_handle, p)))
|
||||
|
||||
def remove_grad_hook(self):
|
||||
for hook in self.grad_hook_list:
|
||||
hook.remove()
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
def _pre_forward(self):
|
||||
self.register_grad_hook()
|
||||
for region in self.region_manager.region_list:
|
||||
region.cpu_grad = None
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
|
||||
self.model.zero_grad(set_to_none=True)
|
||||
self._pre_forward()
|
||||
outputs = self.model(*args, **kwargs)
|
||||
return outputs
|
||||
|
||||
def backward(self, loss):
|
||||
loss.backward()
|
||||
self._post_backward()
|
||||
|
||||
def _post_backward(self):
|
||||
torch.cuda.synchronize()
|
||||
self.remove_grad_hook()
|
||||
|
||||
for p in self.model.parameters():
|
||||
p.grad = None
|
||||
|
||||
GlobalRuntimeInfo.fwd_prefetch_event_map.clear()
|
||||
GlobalRuntimeInfo.bwd_prefetch_event_map.clear()
|
||||
|
||||
def grad_handle(self, p, grad):
|
||||
empty_grad = torch.empty_like(grad)
|
||||
free_storage(empty_grad)
|
||||
with torch._C.DisableTorchFunction():
|
||||
region = self.region_manager.get_region(p)
|
||||
region.copy_grad_to_region_slice(p, grad)
|
||||
if region.can_release:
|
||||
self.overflow_counter += region.has_inf_or_nan
|
||||
master_stream = torch.cuda.current_stream()
|
||||
with torch.cuda.stream(self.grad_offload_stream):
|
||||
GlobalRuntimeInfo.d2h_stream.wait_stream(master_stream)
|
||||
region.move_grad_to_cpu()
|
||||
return empty_grad
|
||||
|
||||
def _cast_buffers(self):
|
||||
for buffer in self.model.buffers():
|
||||
buffer.data = buffer.cuda()
|
||||
|
||||
def parameters(self, recurse: bool = True):
|
||||
return self.model.parameters(recurse)
|
||||
|
||||
def named_parameters(self, prefix: str = '', recurse: bool = True):
|
||||
return self.model.named_parameters(prefix, recurse)
|
||||
|
||||
def named_buffers(self, prefix: str = '', recurse: bool = True):
|
||||
return self.model.named_buffers(prefix, recurse)
|
||||
|
||||
def named_children(self):
|
||||
return self.model.named_children()
|
||||
|
||||
def named_modules(self,
|
||||
memo: Optional[Set[torch.nn.Module]] = None,
|
||||
prefix: str = '',
|
||||
remove_duplicate: bool = True):
|
||||
return self.model.named_modules(memo, prefix, remove_duplicate)
|
|
@ -0,0 +1,49 @@
|
|||
from typing import Dict
|
||||
import torch
|
||||
import torch.fx
|
||||
from torch.fx import GraphModule
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai.fx import ColoTracer, is_compatible_with_meta
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
|
||||
from .region_manager import RegionManager
|
||||
from .runtime import runtime_syn_offload_apply_pass, runtime_asyn_offload_apply_pass
|
||||
from .base_offload_module import BaseOffloadModule
|
||||
from .util import compute_max_param_mem, compute_total_param_mem, compute_act_peak_mem, GlobalRuntimeInfo
|
||||
|
||||
def memory_optimize(model: torch.nn.Module,
|
||||
inps: Dict[str, torch.Tensor],
|
||||
memory_budget: float = -1.0,
|
||||
solver_name: str = 'asyn'):
|
||||
|
||||
model = model.cpu().half()
|
||||
tracer = ColoTracer()
|
||||
assert is_compatible_with_meta()
|
||||
wrap_fn = lambda x: x.to("meta") if isinstance(x, torch.Tensor) else x
|
||||
meta_args = tree_map(wrap_fn, inps)
|
||||
graph = tracer.trace(model, meta_args=meta_args)
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
interp = MetaInfoProp(gm)
|
||||
interp.propagate(*meta_args.values())
|
||||
|
||||
region_manager = RegionManager(graph, solver_name=solver_name, memory_budget=memory_budget)
|
||||
region_manager._build_regions()
|
||||
GlobalRuntimeInfo.region_list = region_manager.region_list
|
||||
|
||||
act_peak_mem = compute_act_peak_mem(region_manager.region_list) / 1024 ** 2
|
||||
max_param_mem = compute_max_param_mem(region_manager.region_list) / 1024 ** 2
|
||||
total_param_mem = compute_total_param_mem(region_manager.region_list) / 1024 ** 2
|
||||
print(
|
||||
f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}")
|
||||
|
||||
if solver_name == 'syn':
|
||||
gm = runtime_syn_offload_apply_pass(gm, region_manager.region_list)
|
||||
elif solver_name == 'asyn':
|
||||
gm = runtime_asyn_offload_apply_pass(gm, region_manager.region_list)
|
||||
else:
|
||||
raise TypeError(f"Unknown solver name {solver_name}!")
|
||||
|
||||
gm.recompile()
|
||||
optimized_model = BaseOffloadModule(gm, region_manager, solver_name=='syn')
|
||||
return optimized_model
|
|
@ -0,0 +1,144 @@
|
|||
from typing import List, Dict, Tuple
|
||||
import torch
|
||||
from torch.fx import Node
|
||||
from colossalai.gemini.tensor_utils import alloc_storage, free_storage
|
||||
|
||||
class Region:
|
||||
"""
|
||||
Region: A container owning a piece of contiguous nodes in the DNN computing graph.
|
||||
|
||||
Args:
|
||||
r_id (int): the index of the region in the computing graph.
|
||||
"""
|
||||
|
||||
def __init__(self, r_id: int = 0) -> None:
|
||||
self.r_id: int = r_id
|
||||
self.fp16_params: List[torch.nn.Parameter] = []
|
||||
self.param_size: int = 0
|
||||
self.shared_rid: int = self.r_id
|
||||
|
||||
self.param_num: int = 0
|
||||
self.grad_num: int = 0
|
||||
self.fp16_data = None
|
||||
self.fp32_data = None
|
||||
self.cpu_grad = None
|
||||
self.temp_fp32_data = None
|
||||
self.param_to_range: Dict[torch.nn.Parameter, Tuple[int, int]] = dict()
|
||||
|
||||
self.need_offload: bool = False
|
||||
self.is_syn: bool = False
|
||||
self.nodes: List[Node] = []
|
||||
self.fwd_prefetch_region = None
|
||||
self.bwd_prefetch_region = None
|
||||
|
||||
self.in_mem_pool_flag: bool = False
|
||||
|
||||
@property
|
||||
def can_release(self) -> bool:
|
||||
"""
|
||||
Check if the region can be released.
|
||||
"""
|
||||
return self.grad_num == self.param_num
|
||||
|
||||
@property
|
||||
def has_inf_or_nan(self) -> bool:
|
||||
"""
|
||||
Check if the grad of the region has inf or nan values on CUDA.
|
||||
"""
|
||||
return torch.isinf(self.fp16_data).any() | torch.isnan(self.fp16_data).any()
|
||||
|
||||
def init_param_data(self, pre_alloc_tensor: torch.Tensor = None):
|
||||
"""
|
||||
Map the parameters in the region to a contiguous memory space.
|
||||
"""
|
||||
|
||||
self.fp16_data = torch.zeros(
|
||||
self.param_num, dtype=torch.half, device='cuda')
|
||||
offset = 0
|
||||
for param in self.fp16_params:
|
||||
param.data = param.data.cuda()
|
||||
p_num = param.data.numel()
|
||||
self.fp16_data[offset:offset + p_num].copy_(param.data.flatten())
|
||||
param.data = self.fp16_data[offset:offset +
|
||||
p_num].view(param.data.shape)
|
||||
self.param_to_range[param] = (offset, offset + p_num)
|
||||
offset += p_num
|
||||
|
||||
self.fp32_data = self.fp16_data.float().cpu().pin_memory()
|
||||
free_storage(self.fp16_data)
|
||||
if self.in_mem_pool_flag and pre_alloc_tensor is not None:
|
||||
self.fp16_data = pre_alloc_tensor
|
||||
|
||||
def move_param_to_cuda(self):
|
||||
"""
|
||||
Move parameters from CPU to GPU.
|
||||
It first moves float32 parameters to GPU and
|
||||
then transforms float32 parameters to half-precision on the GPU.
|
||||
The reason is that the performance of precision conversion on the CPU
|
||||
is much slower than the data transfer overhead.
|
||||
"""
|
||||
|
||||
self.temp_fp32_data.copy_(self.fp32_data, non_blocking=True)
|
||||
self.temp_fp32_data.record_stream(torch.cuda.current_stream())
|
||||
if not self.in_mem_pool_flag:
|
||||
alloc_storage(self.fp16_data)
|
||||
self.fp16_data[:self.param_num].copy_(self.temp_fp32_data)
|
||||
self.fp16_data.record_stream(torch.cuda.current_stream())
|
||||
|
||||
self.__update_params_ptr()
|
||||
|
||||
def move_grad_to_cpu(self):
|
||||
"""
|
||||
Move gradients from GPU to CPU.
|
||||
"""
|
||||
|
||||
self.cpu_grad = torch.empty(self.param_num, dtype=torch.half, pin_memory=True)
|
||||
self.cpu_grad.copy_(self.fp16_data[:self.param_num], non_blocking=True)
|
||||
self.fp16_data.record_stream(torch.cuda.current_stream())
|
||||
if not self.in_mem_pool_flag:
|
||||
self.free_cuda_data()
|
||||
|
||||
self.grad_num = 0
|
||||
|
||||
def free_cuda_data(self):
|
||||
free_storage(self.fp16_data)
|
||||
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
def copy_grad_to_region_slice(self, param: torch.nn.Parameter, data_slice: torch.Tensor) -> None:
|
||||
"""
|
||||
Copy data slice to the memory space indexed by the input tensor in the region.
|
||||
|
||||
Args:
|
||||
param (torch.nn.Parameter): the param used to retrive meta information
|
||||
data_slice (torch.Tensor): the tensor to be copied to the region
|
||||
"""
|
||||
|
||||
begin, end = self.param_to_range[param]
|
||||
self.fp16_data[begin:end].copy_(data_slice.data.flatten())
|
||||
param.data = self.fp16_data[begin:end].view(param.data.shape)
|
||||
|
||||
self.grad_num += data_slice.numel()
|
||||
|
||||
def split(self, cut_node_idx: int, cut_param_idx: int):
|
||||
"""
|
||||
Split the region into two and return the latter.
|
||||
"""
|
||||
new_reg = Region(r_id=self.r_id + 1)
|
||||
new_reg.nodes = self.nodes[cut_node_idx:]
|
||||
new_reg.fp16_params = self.fp16_params[cut_param_idx:]
|
||||
for p in new_reg.fp16_params:
|
||||
new_reg.param_size += p.data.numel() * p.data.element_size()
|
||||
new_reg.param_num += p.data.numel()
|
||||
|
||||
self.nodes = self.nodes[:cut_node_idx]
|
||||
self.fp16_params = self.fp16_params[:cut_param_idx]
|
||||
self.param_size -= new_reg.param_size
|
||||
self.param_num -= new_reg.param_num
|
||||
|
||||
return new_reg
|
||||
|
||||
def __update_params_ptr(self) -> None:
|
||||
for param in self.fp16_params:
|
||||
begin, end = self.param_to_range[param]
|
||||
param.data = self.fp16_data[begin:end].view(param.data.shape)
|
|
@ -0,0 +1,526 @@
|
|||
from typing import List, Any, Dict, Tuple
|
||||
import torch
|
||||
from torch.fx import Graph, Node
|
||||
|
||||
from .solver import SolverFactory
|
||||
from .training_simulator import TrainingSimulator
|
||||
from .region import Region
|
||||
from .util import NodeInfo
|
||||
|
||||
|
||||
class RegionManager:
|
||||
"""
|
||||
RegionManager is used to construct and manage the offload plan for the model execution.
|
||||
|
||||
Args:
|
||||
graph (Graph): a Graph object used for analysis and strategy generation.
|
||||
solver_name (str): a solver name which specifies the preferences for plan searching.
|
||||
memory_budget (float): the given memory budget.
|
||||
cnode (List[str], optional): Common node List, should be the subset of input.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
graph: Graph,
|
||||
solver_name: str = 'asyn',
|
||||
memory_budget: float = -1.0,
|
||||
cnode: List[str] = None):
|
||||
|
||||
self.graph = graph
|
||||
assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
|
||||
self.root_module = self.graph.owning_module
|
||||
self.nodes = list(graph.nodes)
|
||||
self.cnode = cnode
|
||||
self.only_param_ops = []
|
||||
self.param_region_map: Dict[torch.nn.Parameter, Region] = dict()
|
||||
self.shared_region_pairs: List[Tuple[Region, Region]] = list()
|
||||
self.region_list: List[Region] = list()
|
||||
self.rid_in_pool: List[int] = list()
|
||||
self.mem_block_size: int = 0
|
||||
self.memory_budget = memory_budget
|
||||
|
||||
self.solver_name = solver_name
|
||||
self.require_pool: bool = solver_name == 'asyn'
|
||||
|
||||
self.reg_to_block: Dict[int, int] = dict()
|
||||
|
||||
def _build_regions(self):
|
||||
"""
|
||||
1. Pre-processing, mainly contains linearized computing graph and
|
||||
merge smaller regions into larger ones.
|
||||
2. Construct a solver to search for an efficient offload strategy.
|
||||
3. Post-processing, mainly contains early region placement if using asynchronous mode,
|
||||
and initialize region data.
|
||||
"""
|
||||
|
||||
self._pre_process()
|
||||
|
||||
solver_cls = SolverFactory.create(self.solver_name)
|
||||
solver = solver_cls(self.region_list, self.memory_budget)
|
||||
solver._call_solver()
|
||||
|
||||
self._post_process(solver.best_ts)
|
||||
|
||||
def _pre_process(self):
|
||||
|
||||
init_region_list = self._linearize_graph()
|
||||
|
||||
if len(self.shared_region_pairs) > 1:
|
||||
raise NotImplementedError(
|
||||
'The current version only considers at most one pair of parameter sharing.')
|
||||
|
||||
elif len(self.shared_region_pairs) == 1:
|
||||
shared_regs = self.shared_region_pairs[0]
|
||||
assert shared_regs[0].shared_rid == shared_regs[1].r_id \
|
||||
and shared_regs[1].shared_rid == shared_regs[0].r_id
|
||||
fst_id = shared_regs[0].r_id
|
||||
lst_id = shared_regs[1].r_id
|
||||
regs_left_out = init_region_list[:fst_id + 1]
|
||||
regs_right_out = init_region_list[lst_id:]
|
||||
hold_regs = init_region_list[fst_id + 1:lst_id]
|
||||
else:
|
||||
regs_left_out = []
|
||||
regs_right_out = []
|
||||
hold_regs = init_region_list
|
||||
|
||||
self.mem_block_size = self._search_block_size(hold_regs)
|
||||
hold_regs = self._merge_small_regions(hold_regs)
|
||||
|
||||
if self.require_pool:
|
||||
for reg in hold_regs:
|
||||
reg.in_mem_pool_flag = True
|
||||
self.rid_in_pool.append(reg.r_id)
|
||||
|
||||
self.region_list.extend(regs_left_out)
|
||||
self.region_list.extend(hold_regs)
|
||||
|
||||
for reg in regs_right_out:
|
||||
reg.r_id = self.region_list[-1].r_id + 1
|
||||
self.region_list[reg.shared_rid].shared_rid = reg.r_id
|
||||
self.region_list.append(reg)
|
||||
|
||||
self._process_shared_region()
|
||||
|
||||
self.max_param_num = max([reg.param_num for reg in self.region_list])
|
||||
self.memory_budget -= self.max_param_num * torch.tensor([], dtype=torch.float32).element_size()
|
||||
|
||||
def _post_process(self, ts: TrainingSimulator = None):
|
||||
if self.require_pool:
|
||||
self._early_region_placement(ts)
|
||||
self._init_region_data()
|
||||
|
||||
def _early_region_placement(self, ts: TrainingSimulator):
|
||||
"""
|
||||
Implemented the early region placement strategy to avoid GPU memory fragmentation.
|
||||
It maps all region data into a contiguous memory space and
|
||||
reuses the same memory space for regions that do not coexist.
|
||||
|
||||
Args:
|
||||
ts (TrainingSimulator): the best training simulator, which records region execution flow.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: due to the naive implementation,
|
||||
it may not find a suitable region placement strategy for the given execution flow.
|
||||
"""
|
||||
|
||||
reg_flow = torch.cat(
|
||||
[ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0)
|
||||
mem_block_num = torch.max(
|
||||
torch.sum(reg_flow[:, self.rid_in_pool], dim=1))
|
||||
coexist_matrix = torch.logical_or(
|
||||
ts.fwd_reg_flow, ts.bwd_reg_flow)
|
||||
|
||||
block_to_regs = {}
|
||||
for block_idx in range(mem_block_num):
|
||||
block_to_regs[block_idx] = []
|
||||
for reg in self.region_list:
|
||||
if reg.r_id in self.rid_in_pool:
|
||||
cur_reg_appears = coexist_matrix[:, reg.r_id]
|
||||
cur_reg_coexists = torch.sum(
|
||||
coexist_matrix[cur_reg_appears], dim=0).bool()
|
||||
for block_idx in range(mem_block_num):
|
||||
if not any(cur_reg_coexists[block_to_regs[block_idx]]):
|
||||
block_to_regs[block_idx].append(reg.r_id)
|
||||
self.reg_to_block[reg.r_id] = block_idx
|
||||
break
|
||||
|
||||
if reg.r_id not in self.reg_to_block:
|
||||
raise NotImplementedError(
|
||||
f'can not find a block from the memory pool to store parameters of the region')
|
||||
self.memory_pool = torch.chunk(torch.zeros(int(
|
||||
mem_block_num * self.mem_block_size / 2), dtype=torch.half, device='cuda'), chunks=int(mem_block_num))
|
||||
|
||||
def _merge_small_regions(self, orig_reg_list: List[Region]) -> List[Region]:
|
||||
"""
|
||||
Merge smaller regions into larger ones for better bandwidth utilization and easier management.
|
||||
It is inspired by Gemini.
|
||||
|
||||
Args:
|
||||
orig_reg_list (List[Region]): original region list.
|
||||
|
||||
Returns:
|
||||
List[Region]: region list after merging.
|
||||
"""
|
||||
|
||||
r_id = orig_reg_list[0].r_id
|
||||
region = Region(r_id=r_id)
|
||||
region_list = [region]
|
||||
|
||||
for orig_reg in orig_reg_list:
|
||||
if region_list[-1].param_size + orig_reg.param_size > self.mem_block_size:
|
||||
r_id += 1
|
||||
region = Region(r_id=r_id)
|
||||
region_list.append(region)
|
||||
region.param_size += orig_reg.param_size
|
||||
region.param_num += orig_reg.param_num
|
||||
region.nodes.extend(orig_reg.nodes)
|
||||
region.fp16_params.extend(orig_reg.fp16_params)
|
||||
self.__update_param_region_map(orig_reg.fp16_params, region)
|
||||
|
||||
return region_list
|
||||
|
||||
def _search_block_size(self,
|
||||
region_list: List[Region],
|
||||
search_interval_byte: int = 1024,
|
||||
search_range_byte: int = 128 * 1024 ** 2) -> int:
|
||||
"""
|
||||
Search for a suitable memory block size.
|
||||
|
||||
Args:
|
||||
region_list (List[Region]): region list.
|
||||
search_interval_byte (int): searching interval in byte.
|
||||
search_range_byte (int): searching range in byte.
|
||||
|
||||
Returns:
|
||||
int: the best memory block size.
|
||||
"""
|
||||
|
||||
def _get_wasted_mem(size_list: List[int], blk_size: int):
|
||||
"""
|
||||
Get wasted byte for a certain block size.
|
||||
"""
|
||||
acc_wasted = 0
|
||||
left = 0
|
||||
for s in size_list:
|
||||
if left + s > blk_size:
|
||||
acc_wasted += blk_size - left
|
||||
left = s
|
||||
left += s
|
||||
acc_wasted += blk_size - left
|
||||
return acc_wasted
|
||||
|
||||
param_size_list = [
|
||||
region.param_size for region in region_list if region.r_id == region.shared_rid]
|
||||
|
||||
start_size = max(param_size_list)
|
||||
min_mem_waste = float('+inf')
|
||||
best_block_size = start_size
|
||||
|
||||
for block_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
|
||||
temp_waste = 0
|
||||
temp_waste += _get_wasted_mem(param_size_list, block_size)
|
||||
if temp_waste < min_mem_waste:
|
||||
min_mem_waste = temp_waste
|
||||
best_block_size = block_size
|
||||
|
||||
return best_block_size
|
||||
|
||||
def _init_region_data(self):
|
||||
"""
|
||||
Initialize region data, which maps the parameters in the region to a contiguous memory space.
|
||||
"""
|
||||
|
||||
self.temp_fp32_data = torch.zeros(self.max_param_num, device='cuda', dtype=torch.float32)
|
||||
|
||||
for region in self.region_list:
|
||||
pre_alloc_tensor = None
|
||||
if self.require_pool and region.r_id in self.rid_in_pool:
|
||||
block_idx = self.reg_to_block[region.r_id]
|
||||
pre_alloc_tensor = self.memory_pool[block_idx]
|
||||
|
||||
if region.r_id <= region.shared_rid:
|
||||
region.init_param_data(pre_alloc_tensor)
|
||||
else:
|
||||
shared_region = self.region_list[region.shared_rid]
|
||||
region.fp16_data = shared_region.fp16_data
|
||||
region.fp32_data = shared_region.fp32_data
|
||||
region.param_to_range = shared_region.param_to_range
|
||||
region.temp_fp32_data = self.temp_fp32_data[:region.param_num].detach(
|
||||
)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def _process_shared_region(self):
|
||||
"""
|
||||
Special processing for the shared region, which uses GPT2 and Bert case as a priori knowledge.
|
||||
"""
|
||||
|
||||
if len(self.shared_region_pairs):
|
||||
assert len(self.shared_region_pairs) <= 1
|
||||
former_reg, latter_reg = self.shared_region_pairs[0]
|
||||
assert latter_reg.param_num >= former_reg.param_num
|
||||
embedding_node = former_reg.nodes[-1]
|
||||
assert embedding_node.op == 'call_module' and isinstance(
|
||||
self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding)
|
||||
if latter_reg.param_num > former_reg.param_num:
|
||||
for idx, n in enumerate(latter_reg.nodes):
|
||||
if (n.op == 'call_module' and isinstance(self.root_module.get_submodule(n.target),
|
||||
torch.nn.Linear)) or \
|
||||
(n.op == 'call_function' and n.target is torch.nn.functional.linear):
|
||||
cut_node_idx = idx + 1
|
||||
break
|
||||
assert len(latter_reg.fp16_params) == 2
|
||||
new_reg = latter_reg.split(cut_node_idx, 1)
|
||||
for p in new_reg.fp16_params:
|
||||
self.param_region_map[p] = new_reg
|
||||
self.region_list.insert(new_reg.r_id, new_reg)
|
||||
for reg in self.region_list[new_reg.r_id + 1:]:
|
||||
reg.r_id += 1
|
||||
latter_reg.shared_rid = former_reg.r_id
|
||||
former_reg.shared_rid = latter_reg.r_id
|
||||
|
||||
def _linearize_graph(self) -> List[Region]:
|
||||
"""Linearizing the graph
|
||||
|
||||
Args:
|
||||
graph (Graph): The computing graph to be optimized.
|
||||
|
||||
Returns:
|
||||
List[Region]: each region contains the actual 'node' in linearized manner.
|
||||
|
||||
Remarks:
|
||||
Do merge the inplace ops and shape-consistency ops into the previous node.
|
||||
"""
|
||||
|
||||
# List of target name that could be seen as common node
|
||||
common_ops = ["getattr", "getitem", "size"]
|
||||
|
||||
def _is_cop(target: Any) -> bool:
|
||||
"""Check if an op could be seen as common node
|
||||
|
||||
Args:
|
||||
target (Any): node target
|
||||
|
||||
Returns:
|
||||
bool
|
||||
"""
|
||||
|
||||
if isinstance(target, str):
|
||||
return target in common_ops
|
||||
else:
|
||||
return target.__name__ in common_ops
|
||||
|
||||
def _is_act(data: Any) -> bool:
|
||||
"""Check if an op could be seen as parameter computation start
|
||||
|
||||
Args:
|
||||
data (Any): meta_data
|
||||
|
||||
Returns:
|
||||
bool
|
||||
"""
|
||||
|
||||
label = False
|
||||
if isinstance(data, torch.Tensor):
|
||||
return True
|
||||
elif isinstance(data, (tuple, list)):
|
||||
for d in data:
|
||||
label = label or _is_act(d)
|
||||
return label
|
||||
|
||||
def _maybe_param_comp_start() -> bool:
|
||||
"""Check if an op could be seen as parameter computation start
|
||||
|
||||
Args:
|
||||
n (Node): node
|
||||
|
||||
Returns:
|
||||
bool
|
||||
"""
|
||||
|
||||
label = False
|
||||
if n.op == "get_attr":
|
||||
label = True
|
||||
elif n.op == "call_module":
|
||||
target = n.target
|
||||
submod = self.root_module.get_submodule(target)
|
||||
if (
|
||||
len(list(submod.named_parameters(recurse=False))) != 0
|
||||
or len(list(submod.named_buffers(recurse=False))) != 0
|
||||
):
|
||||
label = True
|
||||
|
||||
return label and not sum([v for _, v in param_op_deps.items()])
|
||||
|
||||
def _is_param_comp_end() -> bool:
|
||||
"""Check if an op could be seen as parameter computation end
|
||||
|
||||
Args:
|
||||
n (Node): node
|
||||
|
||||
Returns:
|
||||
bool
|
||||
"""
|
||||
|
||||
def _is_inplace(n: Node):
|
||||
"""Get the inplace argument from ``torch.fx.Node``
|
||||
"""
|
||||
inplace = False
|
||||
if n.op == "call_function":
|
||||
inplace = n.kwargs.get("inplace", False)
|
||||
elif n.op == "call_module":
|
||||
inplace = getattr(n.graph.owning_module.get_submodule(
|
||||
n.target), "inplace", False)
|
||||
return inplace
|
||||
|
||||
label = False
|
||||
|
||||
if n.op == "call_module":
|
||||
target = n.target
|
||||
submod = self.root_module.get_submodule(target)
|
||||
if (
|
||||
len(list(submod.named_parameters(recurse=False))) != 0
|
||||
or len(list(submod.named_buffers(recurse=False))) != 0
|
||||
):
|
||||
label = True
|
||||
|
||||
elif n.op == "call_function":
|
||||
label = any(map(lambda x: x.name in self.only_param_ops, n.all_input_nodes)) and any(
|
||||
map(lambda x: x.name not in self.only_param_ops and not _is_cop(n.target), n.all_input_nodes))
|
||||
|
||||
return label and not sum([v for _, v in param_op_deps.items()]) and not any(map(_is_inplace, n.users))
|
||||
|
||||
def _exception_node_handling():
|
||||
# TODO meta info prop bug
|
||||
if n.name.__contains__("transpose") and n.meta['fwd_out'][0].dim() <= 2:
|
||||
n.meta['fwd_out'] = []
|
||||
|
||||
# make sure that item in cnode is valid
|
||||
if self.cnode:
|
||||
for name in self.cnode:
|
||||
try:
|
||||
assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \
|
||||
f"Common node {name} is not an input of the model."
|
||||
except StopIteration:
|
||||
raise ValueError(f"Common node name {name} not in graph.")
|
||||
else:
|
||||
self.cnode = []
|
||||
|
||||
node_id = 0
|
||||
region_id = 0
|
||||
|
||||
param_op_deps = {}
|
||||
|
||||
deps = {}
|
||||
region_list = []
|
||||
region = Region(r_id=region_id)
|
||||
|
||||
act_n = None
|
||||
|
||||
for n in self.graph.nodes:
|
||||
if n.op != "placeholder" and n.op != "output":
|
||||
for n_par in n.all_input_nodes:
|
||||
if n_par.op != "placeholder" and n_par.name not in self.cnode:
|
||||
deps[n_par] -= 1
|
||||
if n_par.op != "placeholder" and n_par.name in self.only_param_ops:
|
||||
param_op_deps[n_par] -= 1
|
||||
|
||||
if act_n in region.nodes and _maybe_param_comp_start():
|
||||
ns = []
|
||||
border_n_idx = region.nodes.index(act_n)
|
||||
if border_n_idx < len(region.nodes):
|
||||
ns = region.nodes[border_n_idx + 1:]
|
||||
region.nodes = region.nodes[:border_n_idx + 1]
|
||||
region_list.append(region)
|
||||
region_id += 1
|
||||
region = Region(r_id=region_id)
|
||||
region.nodes = ns
|
||||
|
||||
_exception_node_handling()
|
||||
region.nodes.append(n)
|
||||
self._set_node_and_region_info(node_id, n, region)
|
||||
node_id += 1
|
||||
|
||||
# if the node could free all dependencies in graph
|
||||
# we could begin a new region
|
||||
if _is_param_comp_end():
|
||||
region_list.append(region)
|
||||
region_id += 1
|
||||
region = Region(r_id=region_id)
|
||||
|
||||
# propagate common node attr if possible
|
||||
if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode
|
||||
]) or _is_cop(n.target):
|
||||
self.cnode.append(n.name)
|
||||
else:
|
||||
deps[n] = len(
|
||||
[user for user in n.users if user.op != "output"])
|
||||
|
||||
# propagate param node attr if possible
|
||||
if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.only_param_ops
|
||||
]) or n.op == "get_attr":
|
||||
self.only_param_ops.append(n.name)
|
||||
param_op_deps[n] = len(
|
||||
[user for user in n.users if user.op != "output"])
|
||||
|
||||
# record last activation node
|
||||
if _is_act(n._meta_data):
|
||||
act_n = n
|
||||
|
||||
if len(region.nodes):
|
||||
region_list.append(region)
|
||||
|
||||
return region_list
|
||||
|
||||
def _set_node_and_region_info(self, node_id: int, cur_n: Node, cur_reg: Region):
|
||||
|
||||
cur_n.node_info = NodeInfo(node_id)
|
||||
|
||||
if cur_n.op == 'call_module':
|
||||
target = cur_n.target
|
||||
submod = self.root_module.get_submodule(target)
|
||||
for p in list(submod.parameters(recurse=False)):
|
||||
|
||||
if p in self.param_region_map:
|
||||
cur_reg.shared_rid = self.param_region_map[p].r_id
|
||||
self.param_region_map[p].shared_rid = cur_reg.r_id
|
||||
self.shared_region_pairs.append(
|
||||
(self.param_region_map[p], cur_reg))
|
||||
else:
|
||||
self.param_region_map[p] = cur_reg
|
||||
|
||||
cur_reg.fp16_params.append(p)
|
||||
cur_reg.param_num += p.data.numel()
|
||||
cur_reg.param_size += p.data.numel() * p.data.element_size()
|
||||
|
||||
elif cur_n.op == "get_attr":
|
||||
attr_itr = self.root_module
|
||||
atoms = cur_n.target.split(".")
|
||||
for atom in atoms:
|
||||
attr_itr = getattr(attr_itr, atom)
|
||||
|
||||
if isinstance(attr_itr, torch.nn.Parameter):
|
||||
|
||||
if attr_itr in self.param_region_map:
|
||||
cur_reg.shared_rid = self.param_region_map[attr_itr].r_id
|
||||
self.param_region_map[attr_itr].shared_rid = cur_reg.r_id
|
||||
self.shared_region_pairs.append(
|
||||
(self.param_region_map[attr_itr], cur_reg))
|
||||
else:
|
||||
self.param_region_map[attr_itr] = cur_reg
|
||||
|
||||
cur_reg.fp16_params.append(attr_itr)
|
||||
cur_reg.param_num += attr_itr.data.numel()
|
||||
cur_reg.param_size += attr_itr.data.numel() * attr_itr.data.element_size()
|
||||
|
||||
def get_region(self, param: torch.nn.Parameter) -> Region:
|
||||
"""
|
||||
Return the region owning the parameter.
|
||||
|
||||
Args:
|
||||
param (torch.nn.Parameter): a torch parameter object
|
||||
"""
|
||||
return self.param_region_map[param]
|
||||
|
||||
def __update_param_region_map(self, params: List[torch.nn.Parameter], region: Region):
|
||||
for p in params:
|
||||
self.param_region_map[p] = region
|
|
@ -0,0 +1,253 @@
|
|||
from typing import List
|
||||
import torch
|
||||
from torch.fx.node import Node
|
||||
|
||||
from .region import Region
|
||||
from .util import GlobalRuntimeInfo, requires_upload_p_in_fwd
|
||||
|
||||
|
||||
class SynPreFwdPostBwdOP(torch.autograd.Function):
|
||||
"""
|
||||
A customized prefetch and offload operation.
|
||||
|
||||
Args:
|
||||
input_: input tensor.
|
||||
fwd_info: information dict, which contains region indices
|
||||
that need to be uploaded or freed during forward pass.
|
||||
bwd_info: information dict, which contains region indices
|
||||
that need to be uploaded during backward pass.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, fwd_info, bwd_info):
|
||||
ctx.bwd_info = bwd_info
|
||||
d2h_rid = fwd_info.get('d2h_rid', None)
|
||||
if d2h_rid is not None:
|
||||
free_region = GlobalRuntimeInfo.region_list[d2h_rid]
|
||||
assert isinstance(free_region, Region)
|
||||
free_region.free_cuda_data()
|
||||
|
||||
h2d_rid = fwd_info.get('h2d_rid', None)
|
||||
if h2d_rid is not None:
|
||||
h2d_region = GlobalRuntimeInfo.region_list[h2d_rid]
|
||||
assert isinstance(h2d_region, Region)
|
||||
h2d_region.move_param_to_cuda()
|
||||
|
||||
return input_
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
|
||||
h2d_rid = ctx.bwd_info.get('h2d_rid', None)
|
||||
if h2d_rid is not None:
|
||||
pref_region = GlobalRuntimeInfo.region_list[h2d_rid]
|
||||
assert isinstance(pref_region, Region)
|
||||
pref_region.move_param_to_cuda()
|
||||
|
||||
return grad_output, None, None
|
||||
|
||||
|
||||
class AsynPreFwdPostBwdOP(torch.autograd.Function):
|
||||
"""
|
||||
A customized prefetch and offload operation.
|
||||
|
||||
Args:
|
||||
input_: input tensor.
|
||||
fwd_info: information dict, which contains region indices
|
||||
that need to be prefetched, waited, or freed during forward pass.
|
||||
bwd_info: information dict, which contains region indices
|
||||
that need to be prefetched or waited during backward pass.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, fwd_info, bwd_info):
|
||||
ctx.bwd_info = bwd_info
|
||||
|
||||
sync_rid = fwd_info.get('sync_rid', None)
|
||||
if sync_rid is not None:
|
||||
prefetch_event = GlobalRuntimeInfo.fwd_prefetch_event_map.get(
|
||||
sync_rid, None)
|
||||
if prefetch_event:
|
||||
prefetch_event.wait()
|
||||
|
||||
h2d_rid = fwd_info.get('h2d_rid', None)
|
||||
if h2d_rid is not None:
|
||||
pref_region = GlobalRuntimeInfo.region_list[h2d_rid]
|
||||
assert isinstance(pref_region, Region)
|
||||
master_stream = torch.cuda.current_stream()
|
||||
with torch.cuda.stream(GlobalRuntimeInfo.h2d_stream):
|
||||
GlobalRuntimeInfo.h2d_stream.wait_stream(master_stream)
|
||||
pref_region.move_param_to_cuda()
|
||||
|
||||
prefetch_event = torch.cuda.Event()
|
||||
prefetch_event.record(GlobalRuntimeInfo.h2d_stream)
|
||||
GlobalRuntimeInfo.fwd_prefetch_event_map[h2d_rid] = prefetch_event
|
||||
|
||||
return input_
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
|
||||
sync_rid = ctx.bwd_info.get('sync_rid', None)
|
||||
if sync_rid is not None:
|
||||
wait_region = GlobalRuntimeInfo.region_list[sync_rid]
|
||||
assert isinstance(wait_region, Region)
|
||||
prefetch_event = GlobalRuntimeInfo.bwd_prefetch_event_map.get(
|
||||
sync_rid, None)
|
||||
if prefetch_event:
|
||||
prefetch_event.wait()
|
||||
else:
|
||||
wait_region.move_param_to_cuda()
|
||||
|
||||
h2d_rid = ctx.bwd_info.get('h2d_rid', None)
|
||||
if h2d_rid is not None:
|
||||
pref_region = GlobalRuntimeInfo.region_list[h2d_rid]
|
||||
assert isinstance(pref_region, Region)
|
||||
master_stream = torch.cuda.current_stream()
|
||||
with torch.cuda.stream(GlobalRuntimeInfo.h2d_stream):
|
||||
GlobalRuntimeInfo.h2d_stream.wait_stream(master_stream)
|
||||
pref_region.move_param_to_cuda()
|
||||
|
||||
prefetch_event = torch.cuda.Event()
|
||||
prefetch_event.record(GlobalRuntimeInfo.h2d_stream)
|
||||
GlobalRuntimeInfo.bwd_prefetch_event_map[h2d_rid] = prefetch_event
|
||||
return grad_output, None, None
|
||||
|
||||
|
||||
def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info):
|
||||
'''
|
||||
Convert Upload and Offload operation into runtime action.
|
||||
|
||||
Argument:
|
||||
tensor(torch.Tensor): input tensor.
|
||||
fwd_info(dict): information dict, which contains region indices
|
||||
that need to be uploaded, or freed during forward pass.
|
||||
bwd_info(dict): information dict, which contains region indices
|
||||
that need to be uploaded during backward pass.
|
||||
'''
|
||||
with torch._C.DisableTorchFunction():
|
||||
ret = SynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
|
||||
return ret
|
||||
|
||||
def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info):
|
||||
'''
|
||||
Convert Prefetch and Offload operation into runtime action.
|
||||
|
||||
Argument:
|
||||
tensor(torch.Tensor): input tensor.
|
||||
fwd_info(dict): information dict, which contains region indices
|
||||
that need to be prefetched, waited, or freed during forward pass.
|
||||
bwd_info(dict): information dict, which contains region indices
|
||||
that need to be prefetched or waited during backward pass.
|
||||
'''
|
||||
with torch._C.DisableTorchFunction():
|
||||
ret = AsynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
|
||||
return ret
|
||||
|
||||
|
||||
def replace_node_users(orig_node: Node, inserted_node: Node, rep_user_nodes: List[Node] = None):
|
||||
user_list = list(orig_node.users.keys())
|
||||
if rep_user_nodes is not None:
|
||||
user_list = rep_user_nodes
|
||||
for user in user_list:
|
||||
if user == inserted_node:
|
||||
continue
|
||||
new_args = list(user.args)
|
||||
new_kwargs = dict(user.kwargs)
|
||||
# the origin node may be a positional argument or key word argument of user node
|
||||
if orig_node in new_args:
|
||||
# substitute the origin node with offload_apply_node
|
||||
new_args[new_args.index(orig_node)] = inserted_node
|
||||
user.args = tuple(new_args)
|
||||
elif str(orig_node) in new_kwargs:
|
||||
# substitute the origin node with offload_apply_node
|
||||
new_kwargs[str(orig_node)] = inserted_node
|
||||
user.kwargs = new_kwargs
|
||||
|
||||
|
||||
def runtime_syn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[Region]):
|
||||
"""
|
||||
This pass is used to add the synchronous upload and offload spec apply node to the origin graph.
|
||||
"""
|
||||
mod_graph = gm.graph
|
||||
last_inp_node = tuple(mod_graph.nodes)[0]
|
||||
|
||||
for r_idx, region in enumerate(region_list):
|
||||
# forward upload
|
||||
fwd_info = {}
|
||||
if requires_upload_p_in_fwd(region_list[region.shared_rid]):
|
||||
fwd_info['h2d_rid'] = region.r_id
|
||||
|
||||
# forward offload
|
||||
if r_idx > 0 and region_list[r_idx - 1].need_offload:
|
||||
fwd_info['d2h_rid'] = r_idx - 1
|
||||
|
||||
bwd_info = {}
|
||||
# backward upload
|
||||
if r_idx > 0 and region_list[r_idx - 1].need_offload:
|
||||
bwd_info['h2d_rid'] = region_list[r_idx - 1].r_id
|
||||
|
||||
if fwd_info or bwd_info:
|
||||
with mod_graph.inserting_after(last_inp_node):
|
||||
new_node = mod_graph.create_node('call_function', convert_fwd_upload_bwd_offload_to_action,
|
||||
args=(last_inp_node, fwd_info, bwd_info))
|
||||
replace_node_users(last_inp_node, new_node)
|
||||
|
||||
last_inp_node = region.nodes[-1]
|
||||
|
||||
return gm
|
||||
|
||||
|
||||
def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[Region]):
|
||||
"""
|
||||
This pass is used to add the asynchronous prefetch and offload spec apply node to the origin graph.
|
||||
"""
|
||||
mod_graph = gm.graph
|
||||
|
||||
# upload parameters of the first region
|
||||
last_inp_node = tuple(mod_graph.nodes)[0]
|
||||
first_region_with_p = [
|
||||
region for region in region_list if region.param_size][0]
|
||||
fwd_info = {"h2d_rid": first_region_with_p.r_id}
|
||||
with mod_graph.inserting_after(last_inp_node):
|
||||
upload_apply_node = mod_graph.create_node('call_function', convert_fwd_upload_bwd_offload_to_action,
|
||||
args=(last_inp_node, fwd_info, {}))
|
||||
replace_node_users(last_inp_node, upload_apply_node)
|
||||
last_inp_node = upload_apply_node
|
||||
|
||||
for r_idx, region in enumerate(region_list):
|
||||
# forward prefetch
|
||||
fwd_info = {}
|
||||
if region.param_size:
|
||||
fwd_info['sync_rid'] = region.r_id
|
||||
fwd_prefetch_region = region.fwd_prefetch_region
|
||||
if fwd_prefetch_region and requires_upload_p_in_fwd(region_list[fwd_prefetch_region.shared_rid]):
|
||||
fwd_info['h2d_rid'] = fwd_prefetch_region.r_id
|
||||
|
||||
# forward offload
|
||||
if r_idx > 0 and region_list[r_idx-1].need_offload:
|
||||
fwd_info['d2h_rid'] = r_idx - 1
|
||||
|
||||
bwd_info = {}
|
||||
# backward prefetch
|
||||
if r_idx > 0 and region_list[r_idx-1].need_offload:
|
||||
bwd_info['sync_rid'] = r_idx - 1
|
||||
if r_idx > 0 and region_list[r_idx-1].bwd_prefetch_region:
|
||||
bwd_info['h2d_rid'] = region_list[r_idx-1].bwd_prefetch_region.r_id
|
||||
|
||||
if fwd_info or bwd_info:
|
||||
with mod_graph.inserting_after(last_inp_node):
|
||||
new_node = mod_graph.create_node('call_function', convert_fwd_prefetch_bwd_offload_to_action,
|
||||
args=(last_inp_node, fwd_info, bwd_info))
|
||||
replace_node_users(last_inp_node, new_node)
|
||||
|
||||
last_inp_node = region.nodes[-1]
|
||||
|
||||
if region.bwd_prefetch_region:
|
||||
bwd_info = {'h2d_rid': region.bwd_prefetch_region.r_id}
|
||||
with mod_graph.inserting_after(last_inp_node):
|
||||
new_node = mod_graph.create_node('call_function', convert_fwd_prefetch_bwd_offload_to_action,
|
||||
args=(last_inp_node, {}, bwd_info))
|
||||
replace_node_users(last_inp_node, new_node)
|
||||
# gm.graph.print_tabular()
|
||||
return gm
|
|
@ -0,0 +1,523 @@
|
|||
import time
|
||||
from typing import List, Dict, Type
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
NOT_NVML = False
|
||||
try:
|
||||
from pynvml import *
|
||||
except:
|
||||
NOT_NVML = True
|
||||
|
||||
import torch
|
||||
from torch.fx.node import Node
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
|
||||
from .training_simulator import TrainingSimulator, SynTrainingSimulator, AsynTrainingSimulator
|
||||
from .region import Region
|
||||
from .util import NodeInfo, NvDevicePower
|
||||
|
||||
|
||||
def benchmark_func(func, number=1, repeat=1, warmup=3):
|
||||
"""
|
||||
benchmark data transfer cost.
|
||||
"""
|
||||
|
||||
for i in range(warmup):
|
||||
func()
|
||||
|
||||
costs = []
|
||||
|
||||
for i in range(repeat):
|
||||
torch.cuda.synchronize()
|
||||
begin = time.time()
|
||||
for i in range(number):
|
||||
func()
|
||||
torch.cuda.synchronize()
|
||||
costs.append((time.time() - begin) / number)
|
||||
|
||||
return sum(costs) / len(costs)
|
||||
|
||||
|
||||
class Solver(ABC):
|
||||
"""
|
||||
The parameter offload solver.
|
||||
|
||||
Args:
|
||||
region_list (List[Region]): represents the linearized DNN computing graph.
|
||||
memory_budget (float): the given memory budget.
|
||||
error_factor (float): the error factor.
|
||||
It is used to reduce the memory budget. Due to some errors in the estimation of peak memory and execution time.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
region_list: List[Region],
|
||||
memory_budget: float = -1.0,
|
||||
error_factor: float = 0.95) -> None:
|
||||
|
||||
self.region_list = region_list
|
||||
|
||||
self.error_factor: float = error_factor
|
||||
if memory_budget > 0:
|
||||
self.memory_budget = memory_budget * self.error_factor
|
||||
else:
|
||||
self.memory_budget = torch.cuda.get_device_properties(
|
||||
get_current_device()).total_memory * self.error_factor
|
||||
|
||||
self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth()
|
||||
self.comp_power: float = self._extract_computing_power()
|
||||
|
||||
@abstractmethod
|
||||
def _call_solver(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _try_to_offload(self, *args):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _eval_one_choice(self, *args):
|
||||
raise NotImplementedError
|
||||
|
||||
def _compute_offload_profit(self, total_mem_saving: float, peak_mem_saving: float, extra_cost: float):
|
||||
"""
|
||||
Compute the profits of the offload strategies,
|
||||
which packages the memory savings information for subsequent comparisons.
|
||||
|
||||
Args:
|
||||
total_mem_saving (float): the total memory saving of the offload strategy.
|
||||
peak_mem_saving (float): the peak memory saving of the offload strategy.
|
||||
extra_cost (float): extra data transfer cost.
|
||||
|
||||
Returns:
|
||||
tuple: profit information, the first term represents memory savings per unit of time.
|
||||
"""
|
||||
|
||||
if extra_cost == 0:
|
||||
# means data transfer overhead can be completely overlapped
|
||||
return (float('inf'), total_mem_saving, peak_mem_saving)
|
||||
return (total_mem_saving / extra_cost, total_mem_saving, peak_mem_saving)
|
||||
|
||||
def _compare_profit(self, profit_a: tuple, profit_b: tuple) -> bool:
|
||||
"""
|
||||
Compare the profits of the two offload strategies using the dictionary order algorithm.
|
||||
|
||||
Args:
|
||||
profit_a (tuple): the profit of a offload strategy.
|
||||
profit_b (tuple): the profit of another offload strategy.
|
||||
|
||||
Returns:
|
||||
bool: whether profit_a is greater than profit_b.
|
||||
"""
|
||||
|
||||
for val1, val2 in zip(profit_a, profit_b):
|
||||
if val1 != val2:
|
||||
return val1 > val2
|
||||
return False
|
||||
|
||||
def _update_state(self, best_ts: TrainingSimulator):
|
||||
"""
|
||||
Update the solver state.
|
||||
"""
|
||||
|
||||
self.best_ts = best_ts
|
||||
self._update_node_mem_info(best_ts.fwd_node_mem, best_ts.bwd_node_mem)
|
||||
|
||||
def _update_node_mem_info(self,
|
||||
fwd_mem_info: Dict[Node, float],
|
||||
bwd_mem_info: Dict[Node, float]):
|
||||
"""
|
||||
Update the runtime memory information of the node.
|
||||
|
||||
Args:
|
||||
fwd_mem_info (Dict[Node, float]): the runtime memory of each node in forward pass.
|
||||
bwd_mem_info (Dict[Node, float]): the runtime memory of each node in backward pass.
|
||||
"""
|
||||
|
||||
for node, mem in fwd_mem_info.items():
|
||||
assert hasattr(node, 'node_info') and isinstance(
|
||||
node.node_info, NodeInfo)
|
||||
node.node_info.runtime_fwd_mem = mem
|
||||
for node, mem in bwd_mem_info.items():
|
||||
assert hasattr(node, 'node_info') and isinstance(
|
||||
node.node_info, NodeInfo)
|
||||
node.node_info.runtime_bwd_mem = mem
|
||||
|
||||
def _extract_computing_power(self):
|
||||
"""
|
||||
return the FP16 computing performance of the current NVIDIA GPU.
|
||||
|
||||
Raises:
|
||||
TypeError: Unknown NVIDIA GPU device.
|
||||
"""
|
||||
|
||||
nvmlInit()
|
||||
handle = nvmlDeviceGetHandleByIndex(0)
|
||||
device_name = nvmlDeviceGetName(handle)
|
||||
units = 1e12
|
||||
|
||||
if device_name.__contains__("RTX 3080"):
|
||||
return NvDevicePower.RTX3080_FP16 * units
|
||||
elif device_name.__contains__("RTX 3090"):
|
||||
return NvDevicePower.RTX3090_FP16 * units
|
||||
elif device_name.__contains__('V100'):
|
||||
return NvDevicePower.V100_FP16 * units
|
||||
elif device_name.__contains__("A100"):
|
||||
return NvDevicePower.A100_FP16 * units
|
||||
else:
|
||||
raise TypeError(f'Unknown NVIDIA GPU device name {device_name}')
|
||||
|
||||
def _profile_bandwidth(self):
|
||||
"""
|
||||
Profile the bidirectional communication bandwidth between CPU and GPU
|
||||
using data volumes ranging from 1KB to 1GB.
|
||||
"""
|
||||
|
||||
print('profiling bandwidth ......')
|
||||
link_to_bandwidth = {}
|
||||
links = ['h2d', 'd2h']
|
||||
|
||||
for link in links:
|
||||
t_size = 1024
|
||||
size_to_bandwidth = {}
|
||||
|
||||
# from 1KB to 1GB
|
||||
for i in range(21):
|
||||
if link == 'h2d':
|
||||
src_tensor = torch.ones(
|
||||
int(t_size), dtype=torch.int8, pin_memory=True)
|
||||
dst_tensor = torch.ones(
|
||||
(int(t_size)), dtype=torch.int8, device='cuda')
|
||||
elif link == 'd2h':
|
||||
src_tensor = torch.ones(
|
||||
int(t_size), dtype=torch.int8, device='cuda')
|
||||
dst_tensor = torch.ones(
|
||||
(int(t_size)), dtype=torch.int8, pin_memory=True)
|
||||
|
||||
def func():
|
||||
dst_tensor.copy_(src_tensor)
|
||||
|
||||
size_to_bandwidth[t_size] = t_size / benchmark_func(func, number=5, repeat=3)
|
||||
                # record the measured bandwidth for this transfer size, then double the size
                print(f'size: {t_size / 1024 ** 2:.3f} MB, '
                      f'{src_tensor.device.type}-to-{dst_tensor.device.type} '
                      f'bandwidth: {size_to_bandwidth[t_size] / 1024 ** 3:.3f} GB/s')

                t_size *= 2

            link_to_bandwidth[link] = size_to_bandwidth

        return link_to_bandwidth


class SynGreedySolver(Solver):

    def __init__(self,
                 region_list: List[Region],
                 memory_budget: float = -1.0) -> None:
        super().__init__(region_list, memory_budget)

        self.best_ts: SynTrainingSimulator = None
        self._init_state()

    def _init_state(self):
        """
        Initialize the solver state with no offloading applied.
        """
        ts = SynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
        ts.execute()
        self._update_state(ts)

    def _call_solver(self):
        """
        Call the solver to search for an efficient parameter offloading strategy for the linearized graph.
        The solver adopts a greedy algorithm.

        Raises:
            NotImplementedError: Unable to find a solution for the given memory budget.
        """
        print("searching for an offloading strategy ......")
        while self.best_ts.peak_mem > self.memory_budget:
            offload_region = None
            best_ts = None
            max_profit = (0,)

            # search for the region to offload;
            # the last region does not need to be offloaded.
            for region in self.region_list[:-1]:
                if region.param_size and not region.need_offload:
                    temp_ts, profit = self._try_to_offload(region)
                    if self._compare_profit(profit, max_profit):
                        offload_region = region
                        max_profit = profit
                        best_ts = temp_ts

            if offload_region is not None and best_ts is not None:
                offload_region.need_offload = True
                offload_region.is_syn = True
                self._update_state(best_ts)
            else:
                raise NotImplementedError(
                    f"can't find an offload strategy that meets the memory budget {self.memory_budget / 1024 ** 2} MB, "
                    f"it needs at least {self.best_ts.peak_mem / 1024 ** 2:.3f} MB!")

    def _call_solver_l2l(self):
        """
        The layer-wise offload strategy.
        """
        for region in self.region_list[:-1]:
            region.need_offload = True
            region.is_syn = True

    def _try_to_offload(self, offload_region: Region):

        # record previous information
        orig_need_offload = offload_region.need_offload
        assert not orig_need_offload
        offload_region.need_offload = True

        ts, profit = self._eval_one_choice(offload_region)

        # restore previous information
        offload_region.need_offload = orig_need_offload
        return ts, profit

    def _eval_one_choice(self, offload_region: Region):
        """
        Evaluate the profit of a strategy choice.

        Args:
            offload_region (Region): the offload region of the current choice.

        Returns:
            SynTrainingSimulator: the training simulator corresponding to the current strategy.
            tuple: the memory saving and cost information of the current strategy.
        """
        ts = SynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
        ts.execute()

        extra_comm_cost = 2.0 * ts._get_communication_overhead('h2d', offload_region.param_size)
        # a shared region needs to be moved twice
        if offload_region.r_id < offload_region.shared_rid:
            extra_comm_cost *= 2.0
        profit = self._compute_offload_profit(ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)

        return ts, profit


class AsynGreedySolver(Solver):

    def __init__(self,
                 region_list: List[Region],
                 memory_budget: float = -1.0,
                 search_window_size: int = 3):
        super().__init__(region_list, memory_budget)

        self.search_window_size = search_window_size
        # records the prefetch execution location of each offloaded region
        self.region_to_region_map = {}
        self.best_ts: AsynTrainingSimulator = None

        self._init_state()

    def _init_state(self):
        """
        Initialize the solver state with no offloading applied.
        """
        ts = AsynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
        ts.execute()
        self._update_state(ts)
        print("init peak memory", self.best_ts.peak_mem / 1024 ** 2, "MB")

    def _call_solver(self):
        """
        Call the solver to search for an efficient parameter offloading strategy for the linearized graph.
        The solver adopts a greedy algorithm.

        Raises:
            NotImplementedError: Unable to find a solution for the given memory budget.
        """
        print("searching for an offloading strategy ......")
        # records the prefetch execution location of each offloaded region
        region_to_region_map = {}
        while self.best_ts.peak_mem > self.memory_budget:
            region_to_offload = None
            max_offload_profit = (0,)
            best_offl_ts = None

            # search for the region to offload;
            # the last region does not need to be offloaded
            for region in self.region_list[:-1]:
                if region.param_size and not region.need_offload:
                    max_prefetch_profit = (0,)
                    best_pref_ts = None

                    # search for the point at which to prefetch the offloaded region
                    for host_region in self.region_list[region.r_id + 1:region.r_id + 1 + self.search_window_size]:
                        if host_region.bwd_prefetch_region is not None:
                            continue

                        temp_ts, profit = self._try_to_offload(host_region, region)

                        if self._compare_profit(profit, max_prefetch_profit):
                            region_to_region_map[region.r_id] = host_region
                            max_prefetch_profit = profit
                            best_pref_ts = temp_ts
                            if profit[0] == float('inf'):
                                break

                    if self._compare_profit(max_prefetch_profit, max_offload_profit):
                        region_to_offload = region
                        max_offload_profit = max_prefetch_profit
                        best_offl_ts = best_pref_ts

            if (region_to_offload is not None) and (best_offl_ts is not None):
                region_to_offload.need_offload = True
                if region_to_region_map[region_to_offload.r_id] == region_to_offload:
                    region_to_offload.is_syn = True
                else:
                    region_to_region_map[region_to_offload.r_id].bwd_prefetch_region = region_to_offload
                    self.region_to_region_map[region_to_offload.r_id] = region_to_region_map[region_to_offload.r_id]

                self._update_state(best_offl_ts)

            elif self.region_to_region_map.__len__() > 0:
                self._repair_strategy()
            else:
                raise NotImplementedError(
                    f"can't find an offload strategy that meets the memory budget {self.memory_budget / 1024 ** 2} MB, "
                    f"it needs at least {self.best_ts.peak_mem / 1024 ** 2:.3f} MB!")

            region_to_region_map.clear()

    def _try_to_offload(self, host_region: Region, offload_region: Region):
        """
        Attempt to offload the region and prefetch it in the backward pass.
        """

        # record previous information
        orig_prefetch = host_region.bwd_prefetch_region
        orig_is_syn = offload_region.is_syn
        orig_need_offload = offload_region.need_offload

        if host_region == offload_region:
            offload_region.is_syn = True
        else:
            host_region.bwd_prefetch_region = offload_region
        offload_region.need_offload = True

        ts, profit = self._eval_one_choice()

        # restore previous information
        host_region.bwd_prefetch_region = orig_prefetch
        offload_region.is_syn = orig_is_syn
        offload_region.need_offload = orig_need_offload

        return ts, profit

    def _try_convert_to_syn_upload(self, host_region: Region, offload_region: Region):
        """
        Attempt to convert an asynchronous prefetch into a synchronous upload operation.
        """

        # record previous information
        orig_prefetch = host_region.bwd_prefetch_region
        orig_is_syn = offload_region.is_syn
        assert orig_prefetch is not None and not orig_is_syn

        host_region.bwd_prefetch_region = None
        offload_region.is_syn = True

        ts, profit = self._eval_one_choice()

        # restore previous information
        host_region.bwd_prefetch_region = orig_prefetch
        offload_region.is_syn = orig_is_syn

        return ts, profit

    def _repair_strategy(self):
        """
        Repair the offload strategy.
        It attempts to convert asynchronous prefetches into synchronous upload operations and selects the best one.
        The repair process does not end until the peak memory is reduced or no asynchronous prefetch operation is left.
        """
        print("repair strategy ......")

        peak_mem_saving = 0
        while len(self.region_to_region_map) and peak_mem_saving <= 0:

            max_profit = (0,)
            best_ts = None
            undo_host_region = None
            undo_offload_region = None

            for offload_region_id, host_region in self.region_to_region_map.items():
                offload_region = self.region_list[offload_region_id]
                assert host_region.bwd_prefetch_region == offload_region
                assert offload_region.need_offload
                assert not offload_region.is_syn

                ts, profit = self._try_convert_to_syn_upload(host_region, offload_region)

                if self._compare_profit(profit, max_profit):
                    undo_host_region = host_region
                    undo_offload_region = offload_region
                    max_profit = profit
                    best_ts = ts

            if best_ts is None:
                raise NotImplementedError('repair error!')

            assert not undo_offload_region.is_syn
            undo_offload_region.is_syn = True
            undo_host_region.bwd_prefetch_region = None

            peak_mem_saving = self.best_ts.peak_mem - best_ts.peak_mem

            self._update_state(best_ts)
            self.region_to_region_map.pop(undo_offload_region.r_id)

        return best_ts

    def _eval_one_choice(self):
        """
        Evaluate the profit of a strategy choice.

        Returns:
            AsynTrainingSimulator: the training simulator corresponding to the current strategy.
            tuple: the memory saving and cost information of the current strategy.
        """
        ts = AsynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
        ts.execute()

        extra_comm_cost = max(ts.iter_end_time - self.best_ts.iter_end_time, 0)
        profit = self._compute_offload_profit(ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)

        return ts, profit


class SolverFactory:
    solvers: Dict[str, Type[Solver]] = {
        'syn': SynGreedySolver,
        'asyn': AsynGreedySolver
    }

    @staticmethod
    def create(solver_name: str) -> Type[Solver]:
        if solver_name not in SolverFactory.solvers:
            raise TypeError(f"Unknown parameter offload policy {solver_name}")
        return SolverFactory.solvers[solver_name]

    @staticmethod
    def get_solver_names():
        return tuple(SolverFactory.solvers.keys())

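As a usage note (not part of the commit itself), the two greedy solvers are normally obtained through `SolverFactory` and run against a region list produced by `RegionManager`, as in `test_solver.py` later in this commit. A minimal sketch, assuming `region_list` has already been built and the budget is given in bytes:

```python
# Minimal sketch: select a solver by name and run the greedy search.
# `region_list` is assumed to come from RegionManager (see test_solver.py below).
from colossalai.auto_parallel.offload.solver import SolverFactory

memory_budget = 4000 * 1024 * 1024               # bytes
solver_cls = SolverFactory.create('asyn')        # 'syn' selects SynGreedySolver
solver = solver_cls(region_list, memory_budget)
solver._call_solver()                            # raises NotImplementedError if the budget is infeasible

assert solver.best_ts.peak_mem < memory_budget
for region in region_list:
    print(region.r_id, region.need_offload, region.is_syn)
```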
@ -0,0 +1,458 @@
import bisect
from typing import List, Dict
from collections import OrderedDict
from abc import ABC, abstractmethod

from torch.fx.node import Node

from .region import Region
from .util import *


@dataclass
class ExecutionPeriod:
    start_time: float = 0
    end_time: float = 0


class TrainingSimulator(ABC):
    """
    The training simulator is used to simulate the training process.
    It records computation, communication, and runtime memory during the forward and backward passes.

    Args:
        region_list (List[Region]): represents the linearized DNN computing graph.
        comp_power (float): the NVIDIA GPU FP16 computing power.
        link_to_bw (Dict[str, Dict[float, float]]): communication links and the corresponding bandwidth.
    """

    def __init__(self,
                 region_list: List[Region],
                 comp_power: float,
                 link_to_bw: Dict[str, Dict[float, float]]) -> None:
        self.region_list = region_list
        self.region_num = len(region_list)

        self.runtime_mem: int = 0
        self.peak_mem: int = 0
        self.total_mem_saving: int = 0

        self.fwd_node_mem: Dict[Node, float] = {}
        self.bwd_node_mem: Dict[Node, float] = {}

        # node dependencies in the backward pass
        self.bwd_node_deps: Dict[Node, int] = {}

        self.comp_power: float = comp_power
        self.link_to_bandwidth: Dict[str, Dict[float, float]] = link_to_bw

    @abstractmethod
    def execute(self):
        raise NotImplementedError

    @abstractmethod
    def _eval_fwd_mem_per_region(self, region: Region):
        raise NotImplementedError

    @abstractmethod
    def _eval_bwd_mem_per_region(self, region: Region):
        raise NotImplementedError

    def _get_bandwidth(self, link: str, comm_volumn: float) -> float:
        """
        Get the data transfer bandwidth.

        Args:
            link (str): the data transfer link.
            comm_volumn (float): the amount of data transferred.

        Returns:
            float: the data transfer bandwidth.
        """
        assert len(self.link_to_bandwidth)
        if link not in self.link_to_bandwidth:
            raise TypeError(f"Unknown data transfer link {link}")

        # size_list = sorted(list(map(float, self.link_to_bandwidth[link].keys())))
        size_list = sorted(self.link_to_bandwidth[link].keys())
        d_idx = bisect.bisect_left(size_list, comm_volumn)
        return self.link_to_bandwidth[link][size_list[d_idx]]

    def _get_communication_overhead(self, link: str, comm_volumn: float) -> float:
        return comm_volumn / self._get_bandwidth(link, comm_volumn)

    def _get_computing_overhead(self, flop: float) -> float:
        return flop / self.comp_power


class SynTrainingSimulator(TrainingSimulator):

    def __init__(self,
                 region_list: List[Region],
                 comp_power: float,
                 link_to_bw: Dict[str, Dict[float, float]]) -> None:
        super().__init__(region_list, comp_power, link_to_bw)

    def execute(self):
        """
        Simulate the synchronous training process.
        """
        for reg in self.region_list:
            self._eval_fwd_mem_per_region(reg)

        for reg in self.region_list.__reversed__():
            self._eval_bwd_mem_per_region(reg)

    def _eval_fwd_mem_per_region(self, region: Region):
        """
        Evaluate the runtime and peak memory when the forward execution reaches the current region.
        """

        # upload parameters of the current region
        if requires_upload_p_in_fwd(self.region_list[region.shared_rid]):
            self.runtime_mem += region.param_size

        for node in region.nodes:
            self.runtime_mem += calculate_fwd_tmp(node) + calculate_fwd_out(node)
            self.fwd_node_mem[node] = self.runtime_mem
            self.peak_mem = max(self.runtime_mem, self.peak_mem)
            self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem

        if region.need_offload:
            self.runtime_mem -= region.param_size

    def _eval_bwd_mem_per_region(self, region: Region):
        """
        Evaluate the runtime and peak memory when the backward execution reaches the current region.
        """

        # upload parameters of the current region
        if region.need_offload:
            self.runtime_mem += region.param_size

        # add the gradient of the parameters
        if region.r_id < region.shared_rid:
            # gradient accumulation is required for shared parameters
            self.runtime_mem += 2.0 * region.param_size
        else:
            self.runtime_mem += region.param_size

        for node in region.nodes.__reversed__():

            self.runtime_mem -= calculate_fwd_out(node)
            self.runtime_mem += node.meta['bwd_mem_tmp'] + node.meta['bwd_mem_out']
            self.peak_mem = max(self.runtime_mem, self.peak_mem)

            # the memory saving of a node may be negative due to parameter prefetch
            self.total_mem_saving += node.node_info.runtime_bwd_mem - self.runtime_mem
            self.bwd_node_mem[node] = self.runtime_mem

            self.runtime_mem -= (node.meta['bwd_mem_tmp'] + calculate_fwd_tmp(node))

            # free bwd_mem_out
            self.bwd_node_deps[node] = len(node.all_input_nodes)
            for user_node in node.users:
                if user_node in self.bwd_node_deps:
                    self.bwd_node_deps[user_node] -= 1
                    if self.bwd_node_deps[user_node] <= 0:
                        self.runtime_mem -= user_node.meta['bwd_mem_out']

            if self.runtime_mem < 0:
                raise ValueError(f"region id: {region.r_id}, node name: {node.name}, "
                                 f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f} MB --- "
                                 f"the computed runtime memory is less than 0, which indicates a miscalculation!")

        # release parameters and offload gradients of the region
        if region.r_id == region.shared_rid:
            self.runtime_mem -= 2.0 * region.param_size
        elif region.r_id < region.shared_rid:
            self.runtime_mem -= 3.0 * region.param_size
        elif self.region_list[region.shared_rid].need_offload:
            self.runtime_mem -= region.param_size


class AsynTrainingSimulator(TrainingSimulator):

    def __init__(self,
                 region_list: List[Region],
                 comp_power: float,
                 link_to_bw: Dict[str, Dict[float, float]]) -> None:
        super().__init__(region_list, comp_power, link_to_bw)

        self.iter_end_time: int = 0
        # the last computation execution period
        self.last_comp: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0)
        # the last parameter prefetch execution period
        self.last_h2d: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0)
        # the last gradient offload execution period
        self.last_d2h: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0)
        # the forward computation execution period of each region
        self.fwd_reg_to_comp: OrderedDict[int, ExecutionPeriod] = OrderedDict()
        # the forward parameter prefetch execution period of each region
        self.fwd_reg_to_pref: OrderedDict[int, ExecutionPeriod] = OrderedDict()
        # the backward computation execution period of each region
        self.bwd_reg_to_comp: OrderedDict[int, ExecutionPeriod] = OrderedDict()
        # the backward parameter prefetch execution period of each region
        self.bwd_reg_to_pref: OrderedDict[int, ExecutionPeriod] = OrderedDict()
        # the gradient offload execution periods of the regions,
        # divided into those that are still waiting and those that have been released
        self.bwd_reg_to_offl_waiting: OrderedDict[int, ExecutionPeriod] = OrderedDict()
        self.bwd_reg_to_offl_freed: OrderedDict[int, ExecutionPeriod] = OrderedDict()
        # the region buffer, which records regions that are offloaded but not yet released
        self.reg_buffer_to_free: List[int] = []

        # node dependencies in the backward pass
        self.bwd_node_deps: Dict[Node, int] = {}

        # the region execution flow,
        # where fwd_reg_flow[i, j] denotes whether the parameters of the j-th region are on the GPU
        # when the execution reaches the i-th region.
        self.fwd_reg_flow = torch.zeros((self.region_num, self.region_num)).bool()
        self.bwd_reg_flow = torch.zeros((self.region_num, self.region_num)).bool()

    def execute(self):
        """
        Simulate the asynchronous training process.
        In the forward pass, parameter prefetching runs one region ahead.
        In the backward pass, parameter prefetching is executed at the chosen location,
        and gradient offloading is issued eagerly.
        """
        for reg in self.region_list:
            if reg.param_size and reg.r_id < self.region_num - 1:
                for nr in self.region_list[reg.r_id + 1:]:
                    if nr.param_size and requires_upload_p_in_fwd(self.region_list[nr.shared_rid]):
                        reg.fwd_prefetch_region = nr
                        break
            self._eval_fwd_cost_per_region(reg)
            self._eval_fwd_mem_per_region(reg)

        for reg in self.region_list.__reversed__():
            self._eval_bwd_cost_per_region(reg)
            self._eval_bwd_mem_per_region(reg)

        # release the remaining gradients
        for reg_id, offl_exec in self.bwd_reg_to_offl_waiting.items():
            self.bwd_reg_to_offl_freed[reg_id] = offl_exec
            self.runtime_mem -= self.region_list[reg_id].param_size
        self.bwd_reg_to_offl_waiting.clear()

        self.iter_end_time = max(self.last_comp.end_time, self.last_d2h.end_time)

    def _insert_h2d_exec(self, region: Region, is_fwd: bool = True):
        """
        Insert the parameter prefetch execution period of the current region at the end of the h2d stream.
        """
        pref_start_time = max(self.last_h2d.end_time, self.last_comp.end_time)
        pref_end_time = pref_start_time + 2.0 * self._get_communication_overhead('h2d', region.param_size)
        pref_ep = ExecutionPeriod(start_time=pref_start_time, end_time=pref_end_time)
        if is_fwd:
            self.fwd_reg_to_pref[region.r_id] = pref_ep
        else:
            self.bwd_reg_to_pref[region.r_id] = pref_ep
        self.last_h2d = pref_ep

    def _insert_comp_exec(self, region: Region, is_fwd: bool = True):
        """
        Insert the computation execution period of the current region at the end of the computing stream.
        """
        if is_fwd:
            reg_to_comp = self.fwd_reg_to_comp
            reg_to_pref = self.fwd_reg_to_pref
            flop_key = 'fwd_flop'
        else:
            reg_to_comp = self.bwd_reg_to_comp
            reg_to_pref = self.bwd_reg_to_pref
            flop_key = 'bwd_flop'
        comp_start_time = max(self.last_comp.end_time,
                              reg_to_pref.get(region.r_id, ExecutionPeriod(0, 0)).end_time)
        comp_end_time = comp_start_time + sum(
            [self._get_computing_overhead(node.meta.get(flop_key, 0)) for node in region.nodes])
        comp_ep = ExecutionPeriod(start_time=comp_start_time, end_time=comp_end_time)
        reg_to_comp[region.r_id] = comp_ep
        self.last_comp = comp_ep

    def _insert_d2h_exec(self, region: Region):
        """
        Insert the gradient offload execution period of the current region at the end of the d2h stream.
        """
        offl_start_time = max(self.last_d2h.end_time, self.last_comp.end_time)
        offl_end_time = offl_start_time + self._get_communication_overhead('d2h', region.param_size)
        offl_ep = ExecutionPeriod(start_time=offl_start_time, end_time=offl_end_time)
        self.bwd_reg_to_offl_waiting[region.r_id] = offl_ep
        self.last_d2h = offl_ep

    def _eval_fwd_cost_per_region(self, region: Region):
        """
        Evaluate the computation and communication execution periods of the region in the forward pass.
        """

        # upload parameters of the first region
        if region.r_id == 0:
            self._insert_h2d_exec(region)

        # prefetch parameters of the next region
        fwd_prefetch_region = region.fwd_prefetch_region
        if fwd_prefetch_region and requires_upload_p_in_fwd(self.region_list[fwd_prefetch_region.shared_rid]):
            self._insert_h2d_exec(fwd_prefetch_region)

        # execute computation
        self._insert_comp_exec(region)

    def _eval_fwd_mem_per_region(self, region: Region):
        """
        Evaluate the runtime and peak memory when the forward execution reaches the current region.
        """

        # upload parameters of the current region
        if region.r_id <= 0:
            self.runtime_mem += region.param_size
            self.fwd_reg_flow[region.r_id, region.r_id] = True
        else:
            self.fwd_reg_flow[region.r_id] = self.fwd_reg_flow[region.r_id - 1]
            self.fwd_reg_flow[region.r_id, self.reg_buffer_to_free] = False
            self.reg_buffer_to_free.clear()

        # prefetch parameters of the next region
        fwd_prefetch_region = region.fwd_prefetch_region
        if fwd_prefetch_region and requires_upload_p_in_fwd(self.region_list[fwd_prefetch_region.shared_rid]):
            self.runtime_mem += fwd_prefetch_region.param_size
            self.fwd_reg_flow[region.r_id, fwd_prefetch_region.r_id] = True

        for node in region.nodes:
            self.runtime_mem += calculate_fwd_tmp(node) + calculate_fwd_out(node)
            self.peak_mem = max(self.runtime_mem, self.peak_mem)

            self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem
            self.fwd_node_mem[node] = self.runtime_mem

        if region.need_offload:
            self.runtime_mem -= region.param_size

            assert len(self.reg_buffer_to_free) <= 1, f'{len(self.reg_buffer_to_free)}'
            self.reg_buffer_to_free.append(region.r_id)

    def _eval_bwd_cost_per_region(self, region: Region):
        """
        Evaluate the computation and communication execution periods of the region in the backward pass.
        """

        # upload parameters of the current region
        if region.is_syn:
            assert region.need_offload
            self._insert_h2d_exec(region, is_fwd=False)

        # prefetch parameters of the chosen region, overlapped with computation
        if region.bwd_prefetch_region is not None:
            self._insert_h2d_exec(region.bwd_prefetch_region, is_fwd=False)

        # execute computation
        self._insert_comp_exec(region, is_fwd=False)

        # offload gradients
        if requires_offload_g_in_bwd(region):
            self._insert_d2h_exec(region)

        assert len(self.reg_buffer_to_free) == 0
        for reg_id, offl_exec in self.bwd_reg_to_offl_waiting.items():
            if offl_exec.end_time >= self.last_comp.start_time:
                break
            self.reg_buffer_to_free.append(reg_id)
            self.bwd_reg_to_offl_freed[reg_id] = offl_exec

        for reg_id in self.reg_buffer_to_free:
            self.bwd_reg_to_offl_waiting.pop(reg_id)

    def _eval_bwd_mem_per_region(self, region: Region):
        """
        Evaluate the runtime and peak memory when the backward execution reaches the current region.
        """
        if region.r_id + 1 < self.region_num:
            self.bwd_reg_flow[region.r_id] = self.bwd_reg_flow[region.r_id + 1]
        else:
            self.bwd_reg_flow[region.r_id] = self.fwd_reg_flow[-1]
        self.bwd_reg_flow[region.r_id, self.reg_buffer_to_free] = False

        # free the gradients in the buffer
        while len(self.reg_buffer_to_free):
            reg_id = self.reg_buffer_to_free.pop(0)
            self.runtime_mem -= self.region_list[reg_id].param_size

        # upload parameters of the current region
        if region.is_syn:
            self.runtime_mem += region.param_size
            self.bwd_reg_flow[region.r_id, region.r_id] = True

        # prefetch parameters of the chosen region
        bwd_prefetch_region = region.bwd_prefetch_region
        if bwd_prefetch_region:
            self.runtime_mem += bwd_prefetch_region.param_size
            self.bwd_reg_flow[region.r_id, bwd_prefetch_region.r_id] = True

        # add the gradient of the parameters
        if region.r_id < region.shared_rid:
            # gradient accumulation is required for shared parameters
            self.runtime_mem += 2.0 * region.param_size
        else:
            self.runtime_mem += region.param_size

        for node in region.nodes.__reversed__():

            self.runtime_mem -= calculate_fwd_out(node)
            self.runtime_mem += node.meta['bwd_mem_tmp'] + node.meta['bwd_mem_out']
            self.peak_mem = max(self.runtime_mem, self.peak_mem)

            # the memory saving of a node may be negative due to parameter prefetch
            self.total_mem_saving += node.node_info.runtime_bwd_mem - self.runtime_mem

            self.bwd_node_mem[node] = self.runtime_mem

            self.runtime_mem -= (node.meta['bwd_mem_tmp'] + calculate_fwd_tmp(node))

            # free bwd_mem_out
            self.bwd_node_deps[node] = len(node.all_input_nodes)
            for user_node in node.users:
                if user_node in self.bwd_node_deps:
                    self.bwd_node_deps[user_node] -= 1
                    if self.bwd_node_deps[user_node] <= 0:
                        self.runtime_mem -= user_node.meta['bwd_mem_out']

            if self.runtime_mem < 0:
                raise ValueError(f"region id: {region.r_id}, node name: {node.name}, "
                                 f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f} MB --- "
                                 f"the computed runtime memory is less than 0, which indicates a miscalculation!")

        # release parameters of the region
        if requires_release_p_in_bwd(self.region_list[region.shared_rid]):
            self.runtime_mem -= region.param_size

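As a side note, `_get_bandwidth` treats the profiled table as a step function over transfer sizes: the keys are profiled sizes and the lookup picks the first profiled size at least as large as the requested volume. A minimal sketch of that lookup, with made-up numbers (the real table comes from the solver's bandwidth profiling, and the small clamp below is a guard added only for this sketch):

```python
import bisect

# Illustrative table only; real values come from the bandwidth profiling pass.
link_to_bandwidth = {
    'h2d': {1024.0: 2e9, 1048576.0: 8e9, 1073741824.0: 12e9},   # transfer size (bytes) -> bandwidth (bytes/s)
}


def lookup_bandwidth(link: str, comm_volume: float) -> float:
    size_list = sorted(link_to_bandwidth[link].keys())
    # choose the first profiled size that is >= the requested volume
    d_idx = bisect.bisect_left(size_list, comm_volume)
    # clamp to the largest bucket (the simulator assumes the volume stays in range)
    d_idx = min(d_idx, len(size_list) - 1)
    return link_to_bandwidth[link][size_list[d_idx]]


print(lookup_bandwidth('h2d', 4096.0))            # 8e9: falls into the 1 MiB bucket
print(4096.0 / lookup_bandwidth('h2d', 4096.0))   # the corresponding transfer time in seconds
```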
@ -0,0 +1,90 @@
from dataclasses import dataclass
from typing import List
import torch
from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp

from .region import Region


@dataclass
class NodeInfo:
    node_id: int = 0
    runtime_fwd_mem: float = 0
    runtime_bwd_mem: float = 0


class NvDevicePower:
    """
    NVIDIA GPU computing performance (TFLOPs).
    """

    RTX3080_FP16 = 70
    RTX3080_FP32 = 34.1

    RTX3090_FP16 = 71
    RTX3090_FP32 = 35.7

    V100_FP16 = 31.4
    V100_FP32 = 15.7

    A100_FP16 = 78
    A100_FP32 = 19.5


class GlobalRuntimeInfo:
    h2d_stream = torch.cuda.Stream()
    d2h_stream = torch.cuda.Stream()
    fwd_prefetch_event_map = {}
    bwd_prefetch_event_map = {}
    region_list = []


def compute_act_peak_mem(region_list: List[Region]) -> float:
    act_peak_mem = 0
    runtime_mem = 0
    # forward
    for region in region_list:
        for node in region.nodes:
            runtime_mem = runtime_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node)
            act_peak_mem = max(runtime_mem, act_peak_mem)
    # backward
    bwd_deps = {}
    for region in region_list.__reversed__():
        for node in region.nodes.__reversed__():
            runtime_mem -= calculate_fwd_out(node)
            runtime_mem = runtime_mem + node.meta['bwd_mem_tmp'] + node.meta['bwd_mem_out']

            act_peak_mem = max(runtime_mem, act_peak_mem)

            runtime_mem = runtime_mem - node.meta['bwd_mem_tmp'] - calculate_fwd_tmp(node)

            # free bwd_mem_out
            bwd_deps[node] = len(node.all_input_nodes)
            for user_node in node.users:
                if user_node in bwd_deps:
                    bwd_deps[user_node] -= 1
                    if bwd_deps[user_node] <= 0:
                        runtime_mem -= user_node.meta['bwd_mem_out']

    return act_peak_mem


def compute_max_param_mem(region_list: List[Region]) -> float:
    return max(region.param_size for region in region_list)


def compute_total_param_mem(region_list: List[Region]) -> float:
    return sum(region.param_size for region in region_list if region.r_id <= region.shared_rid)


def requires_upload_p_in_fwd(shared_reg: Region):
    return (shared_reg.r_id >= shared_reg.shared_rid) or (
        shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload)


def requires_release_p_in_bwd(shared_reg: Region):
    return (shared_reg.r_id >= shared_reg.shared_rid) or (
        shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload)


def requires_offload_g_in_bwd(region: Region):
    return region.param_size and (region.r_id <= region.shared_rid)

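A hedged illustration of the shared-parameter predicates above, using a hypothetical stand-in object (`_FakeRegion`) that carries only the attributes these helpers read, and assuming the module is importable as `colossalai.auto_parallel.offload.util` (matching the relative imports used elsewhere in this commit):

```python
from dataclasses import dataclass

from colossalai.auto_parallel.offload.util import requires_upload_p_in_fwd


@dataclass
class _FakeRegion:      # stand-in for Region, for illustration only
    r_id: int
    shared_rid: int
    need_offload: bool = False


# a region that does not share its parameters (shared_rid == r_id) always uploads them in forward
print(requires_upload_p_in_fwd(_FakeRegion(r_id=3, shared_rid=3)))                      # True
# a region that shares parameters with a later region (r_id < shared_rid)
# uploads them in forward only if it is itself marked for offloading
print(requires_upload_p_in_fwd(_FakeRegion(r_id=2, shared_rid=5)))                      # False
print(requires_upload_p_in_fwd(_FakeRegion(r_id=2, shared_rid=5, need_offload=True)))   # True
```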
@ -0,0 +1,37 @@
# Auto-Offload Demo with GPT2

## Requirements

Before you can launch training, you need to install the following requirements.

### Install PyTorch

```bash
#conda
conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch
#pip
pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113
```

### Install [Colossal-AI v0.2.0](https://colossalai.org/download/) From Official Website

```bash
pip install colossalai==0.2.0+torch1.12cu11.3 -f https://release.colossalai.org
```

### Install transformers

```bash
pip install transformers
```

## Dataset

For simplicity, the input data is randomly generated here.

## Training

```bash
# Run auto offload on GPT with the default settings and a dummy dataset.
bash run.sh
```
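The defaults (`gpt2_medium`, batch size 64, a 16 GB memory budget, the `asyn` solver) are read from environment variables in `run.sh`, so a different configuration can be selected without editing the script, for example:

```bash
# overrides consumed by run.sh
MODEL_TYPE=gpt2_xl BATCH_SIZE=32 MEMORY_BUDGET=32 SOLVER_TYPE=syn bash run.sh
```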
@ -0,0 +1,65 @@
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel


class GPTLMModel(nn.Module):

    def __init__(self,
                 hidden_size=768,
                 num_layers=12,
                 num_attention_heads=12,
                 max_seq_len=1024,
                 vocab_size=50257):
        super().__init__()
        self.model = GPT2LMHeadModel(
            GPT2Config(n_embd=hidden_size,
                       n_layer=num_layers,
                       n_head=num_attention_heads,
                       n_positions=max_seq_len,
                       n_ctx=max_seq_len,
                       vocab_size=vocab_size))

    def forward(self, input_ids, attention_mask):
        # Only return lm_logits
        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0]


class GPTLMLoss(nn.Module):

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))


def get_gpt2_components(model_type: str, batch_size: int):
    vocab_size = 1024
    seq_len = 8

    def gpt2_model_builder():
        if model_type == "gpt2_medium":
            return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16)
        elif model_type == "gpt2_xl":
            return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32)
        elif model_type == "gpt2_10b":
            return GPTLMModel(hidden_size=4096, num_layers=50, num_attention_heads=16)
        elif model_type == "gpt2_14b":
            return GPTLMModel(hidden_size=4096, num_layers=70, num_attention_heads=16)
        elif model_type == "gpt2_20b":
            return GPTLMModel(hidden_size=8192, num_layers=25, num_attention_heads=16)
        elif model_type == "gpt2_24b":
            return GPTLMModel(hidden_size=8192, num_layers=30, num_attention_heads=16)
        else:
            raise TypeError(f"Unknown model type {model_type}")

    def gpt2_data_gen(device="cuda"):
        input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
        attention_mask = torch.ones_like(input_ids, device=device)
        kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
        return kwargs

    return gpt2_model_builder, gpt2_data_gen

@ -0,0 +1,2 @@
colossalai >= 0.1.12
torch >= 1.8.1

@ -0,0 +1,8 @@
export BATCH_SIZE=${BATCH_SIZE:-64}
export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"}
export MEMORY_BUDGET=${MEMORY_BUDGET:-16}
export SOLVER_TYPE=${SOLVER_TYPE:-"asyn"}

mkdir -p offload_logs

python train_gpt_offload.py --model_type=${MODEL_TYPE} --memory_budget=${MEMORY_BUDGET} --solver_type=${SOLVER_TYPE} --batch_size=${BATCH_SIZE} 2>&1 | tee ./offload_logs/${MODEL_TYPE}_bs_${BATCH_SIZE}_st_${SOLVER_TYPE}.log

@ -0,0 +1,94 @@
import time
import pytest
import argparse
from functools import partial

import torch
from torch.utils._pytree import tree_map
import torch.multiprocessing as mp

import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.fx.profiler import parameter_size
from colossalai.utils import free_port, get_current_device
from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer
from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
from colossalai.auto_parallel.offload.solver import NOT_NVML
from model_zoo import get_gpt2_components, GPTLMLoss


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_type', type=str, default="gpt2_medium")
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--solver_type', type=str, default='asyn')
    parser.add_argument('--memory_budget', type=float, default=16)
    return parser.parse_args()


@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
def train_gpt(args):
    memory_budget = args.memory_budget * 1024 * 1024 * 1024
    solver_type = args.solver_type
    model_type = args.model_type
    batch_size = args.batch_size

    # build model
    model_builder, data_gen = get_gpt2_components(model_type=model_type, batch_size=batch_size)
    label = torch.randint(low=0, high=128, size=(64, 8,), device=get_current_device())
    criterion = GPTLMLoss()

    start_time = time.time()
    model = model_builder()
    model.train()
    param_size = parameter_size(model) / 1024 ** 2 / 2
    init_time = time.time() - start_time
    print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s")

    data_args = data_gen(device="cpu")
    wrap_fn = lambda x: x.to(dtype=torch.half) if isinstance(x, torch.Tensor) and torch.is_floating_point(x) else x
    data_args = tree_map(wrap_fn, data_args)
    start_time = time.time()
    model = memory_optimize(model, data_args, memory_budget, solver_type)
    solver_time = time.time() - start_time
    print(f"solver_time={solver_time:.3f} s")

    hybrid_optimizer = HybridAdam(model.model.parameters(), lr=1e-3)
    optim = AMPOptimizer(hybrid_optimizer, model)

    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

    time_list = []
    data_args = data_gen(device="cuda")
    data_args = tree_map(wrap_fn, data_args)
    for step in range(10):
        optim.zero_grad()
        torch.cuda.synchronize()
        start_time = time.time()
        loss = criterion(model(**data_args), label)
        optim.backward(loss)
        torch.cuda.synchronize()
        time_list.append(time.time() - start_time)
        optim.step()

    torch.cuda.synchronize()

    exec_time = sum(sorted(time_list)[:5]) / 5
    runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2
    runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2
    print(f'solver_type: {solver_type} | model_type: {model_type}')
    print(
        f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
        f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|'
    )
    print(time_list)


def run(rank, world_size, port, args):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    train_gpt(args)


if __name__ == '__main__':
    args = parse_args()
    run_func = partial(run, world_size=1, port=free_port(), args=args)
    mp.spawn(run_func, nprocs=1)

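For reference, the script above can also be launched directly, bypassing `run.sh`; the flags below mirror `parse_args()` (the memory budget is given in GB):

```bash
python train_gpt_offload.py --model_type=gpt2_medium --batch_size=64 --solver_type=asyn --memory_budget=16
```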
@ -0,0 +1,86 @@
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel
from transformers import BertConfig, BertLMHeadModel
from tests.components_to_test.registry import non_distributed_component_funcs


class GPTLMModel(nn.Module):

    def __init__(self,
                 hidden_size=768,
                 num_layers=12,
                 num_attention_heads=12,
                 max_seq_len=1024,
                 vocab_size=50257):
        super().__init__()
        self.model = GPT2LMHeadModel(
            GPT2Config(n_embd=hidden_size,
                       n_layer=num_layers,
                       n_head=num_attention_heads,
                       n_positions=max_seq_len,
                       n_ctx=max_seq_len,
                       vocab_size=vocab_size))

    def forward(self, input_ids, attention_mask):
        # Only return lm_logits
        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0]


class LMLoss(nn.Module):

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))


class BertLMModel(nn.Module):

    def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=32, vocab_size=30522):
        super().__init__()
        self.model = BertLMHeadModel(
            BertConfig(n_embd=hidden_size,
                       num_hidden_layers=num_layers,
                       hidden_size=hidden_size,
                       num_attention_heads=num_attention_heads,
                       max_position_embeddings=hidden_size,
                       vocab_size=vocab_size))

    def forward(self, input_ids, attention_mask):
        # Only return lm_logits
        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0]


@non_distributed_component_funcs.register(name='bert_')
def get_bert_components():
    vocab_size = 1024
    seq_len = 64
    batchSize = 64

    def bert_model_builder():
        model = BertLMModel(hidden_size=8192, num_layers=4, num_attention_heads=32, vocab_size=vocab_size)
        return model

    def bert_data_gen(device="meta"):
        input_ids = torch.randint(0, vocab_size, (batchSize, seq_len), device=device)
        attention_mask = torch.ones_like(input_ids, device=device)
        kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
        return kwargs

    return bert_model_builder, bert_data_gen


@non_distributed_component_funcs.register(name='gpt2_')
def get_gpt2_components():
    vocab_size = 1024
    seq_len = 8
    batchSize = 64

    def gpt2_model_builder():
        model = GPTLMModel(hidden_size=8192, num_layers=2, num_attention_heads=32, vocab_size=vocab_size)
        return model

    def gpt2_data_gen(device="meta"):
        input_ids = torch.randint(0, vocab_size, (batchSize, seq_len), device=device)
        attention_mask = torch.ones_like(input_ids, device=device)
        kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
        return kwargs

    return gpt2_model_builder, gpt2_data_gen

@ -0,0 +1,150 @@
import time
import pytest
from functools import partial

import torch
from torch.utils._pytree import tree_map
import torch.multiprocessing as mp

import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.fx.profiler import parameter_size
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.utils import free_port, get_current_device
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer
from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
from colossalai.auto_parallel.offload.solver import NOT_NVML
from colossalai.testing import parameterize

from tests.test_tensor.common_utils import set_seed
from tests.test_auto_parallel.test_offload.model_utils import *


@parameterize('model_name', ['gpt2_'])
@parameterize('memory_budget', [5000])
@parameterize('solver_name', ['asyn'])
def exam_fwd_bwd(
        model_name: str,
        memory_budget: float,
        solver_name: str
):

    # build model
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, data_gen = get_components_func()
    label = torch.randint(low=0, high=128, size=(64, 8,), device=get_current_device())
    criterion = LMLoss()

    set_seed(42)
    start_time = time.time()
    model = model_builder()
    model.train()
    param_size = parameter_size(model) / 1024 ** 2 / 2
    init_time = time.time() - start_time
    print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s")

    data_args = data_gen(device="cpu")
    wrap_fn = lambda x: x.to(dtype=torch.half) if isinstance(x, torch.Tensor) and torch.is_floating_point(x) else x
    data_args = tree_map(wrap_fn, data_args)
    start_time = time.time()
    model = memory_optimize(model, data_args, memory_budget * 1024 * 1024, solver_name)
    solver_time = time.time() - start_time
    print(f"solver_time={solver_time:.3f} s")

    hybrid_optimizer = HybridAdam(model.model.parameters(), lr=1e-3)
    optim = AMPOptimizer(hybrid_optimizer, model)

    with ColoInitContext(device=torch.device('cpu')):
        gemini_model = model_builder()
    gemini_model.train()

    hybrid_optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3)
    gemini_config = dict(strict_ddp_mode=False,
                         device=torch.device('cpu'),
                         placement_policy='cpu',
                         pin_memory=True,
                         hidden_dim=8192,
                         search_range_mb=128)
    gemini_model = zero_model_wrapper(gemini_model, 3, gemini_config)
    optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
    gemini_optim = zero_optim_wrapper(gemini_model, hybrid_optimizer, optim_config=optim_config)

    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

    # test gemini
    time_list = []
    set_seed(42)
    data_args = data_gen(device="cuda")
    for step in range(10):
        gemini_optim.zero_grad()
        torch.cuda.synchronize()
        start_time = time.time()
        gemini_out = gemini_model(**data_args)
        gemini_loss = criterion(gemini_out, label)
        gemini_optim.backward(gemini_loss)
        torch.cuda.synchronize()
        time_list.append(time.time() - start_time)
        gemini_optim.step()

    torch.cuda.synchronize()

    exec_time = sum(sorted(time_list)[:5]) / 5
    runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2
    runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2
    print(f'gemini | model_name: {model_name}')
    print(
        f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
        f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|'
    )
    print(time_list)

    del data_args
    del gemini_model
    del gemini_optim
    del gemini_out
    del gemini_loss

    # test asyn offload
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

    time_list = []
    set_seed(42)
    data_args = data_gen(device="cuda")
    data_args = tree_map(wrap_fn, data_args)
    for step in range(10):
        optim.zero_grad()
        torch.cuda.synchronize()
        start_time = time.time()
        loss = criterion(model(**data_args), label)
        optim.backward(loss)
        torch.cuda.synchronize()
        time_list.append(time.time() - start_time)
        optim.step()

    torch.cuda.synchronize()

    exec_time = sum(sorted(time_list)[:5]) / 5
    runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2
    runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2
    print(f'solver_name: {solver_name} | model_name: {model_name}')
    print(
        f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
        f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|'
    )
    print(time_list)


@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
def test_perf(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_fwd_bwd()


if __name__ == '__main__':
    run_func = partial(test_perf, world_size=1, port=free_port())
    mp.spawn(run_func, nprocs=1)

@ -0,0 +1,62 @@
import pytest
import torch.fx
from torch.fx import GraphModule
from torch.utils._pytree import tree_map

from colossalai.fx import ColoTracer, is_compatible_with_meta
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.auto_parallel.offload.region_manager import RegionManager
from colossalai.auto_parallel.offload.solver import SolverFactory, NOT_NVML
from colossalai.testing import parameterize
from tests.test_auto_parallel.test_offload.model_utils import *


@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
@parameterize('model_name', ['gpt2_', 'bert_'])
@parameterize('memory_budget', [4000])
@parameterize('solver_name', ['syn', 'asyn'])
def solver_test(model_name: str,
                memory_budget: float,
                solver_name: str):

    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, data_gen = get_components_func()
    data_args = data_gen(device="cpu")
    wrap_fn = lambda x: x.to(dtype=torch.half) if isinstance(x, torch.Tensor) and torch.is_floating_point(x) else x
    data_args = tree_map(wrap_fn, data_args)
    model = model_builder()
    model.train()
    model = model.cpu().half()

    tracer = ColoTracer()
    assert is_compatible_with_meta()
    wrap_fn = lambda x: x.to("meta") if isinstance(x, torch.Tensor) else x
    meta_args = tree_map(wrap_fn, data_args)
    graph = tracer.trace(model, meta_args=meta_args)
    gm = GraphModule(model, graph, model.__class__.__name__)

    interp = MetaInfoProp(gm)
    interp.propagate(*meta_args.values())

    region_manager = RegionManager(graph, solver_name=solver_name)
    region_manager._pre_process()
    region_list = region_manager.region_list

    solver_cls = SolverFactory.create(solver_name)
    memory_budget = memory_budget * 1024 * 1024
    solver = solver_cls(region_list, memory_budget)
    solver._call_solver()

    assert solver.best_ts.peak_mem < memory_budget

    print("****************** execution plan *******************")
    for region in region_list:
        need_offload = region.need_offload
        to_prefetch = region.fwd_prefetch_region.r_id if region.fwd_prefetch_region is not None else None
        print(f'| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}')
    for region in region_list.__reversed__():
        need_offload = region.need_offload
        to_prefetch = region.bwd_prefetch_region.r_id if region.bwd_prefetch_region is not None else None
        print(f'| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}')


if __name__ == '__main__':
    solver_test()