[auto-parallel] add auto-offload feature (#3154)

* add auto-offload feature

* polish code

* fix syn offload runtime pass bug

* add offload example

* fix offload testing bug

* fix example testing bug
pull/3190/head^2
Zihao 2023-03-21 14:17:41 +08:00 committed by GitHub
parent 258b43317c
commit 18dbe76cae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 2833 additions and 0 deletions

View File

@ -0,0 +1,177 @@
from typing import Dict, Tuple
from enum import Enum
import torch
from torch.optim import Optimizer
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.utils import get_current_device
from .base_offload_module import BaseOffloadModule
from .region_manager import RegionManager
from .region import Region
class OptimState(Enum):
SCALED = 0
UNSCALED = 1
class AMPOptimizer(ColossalaiOptimizer):
"""
A wrapper for Optimizer.
Code reference: https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/optimizer/zero_optimizer.py
Args:
optimizer (Optimizer): An Optimizer instance.
module (BaseOffloadModule): A ``BaseOffloadModule`` instance.
initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**16.
growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
norm_type (float, optional): norm_type used for `clip_grad_norm`.
"""
def __init__(self,
optimizer: Optimizer,
module: BaseOffloadModule,
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
clipping_norm: float = 0.0,
norm_type: float = 2.0):
super().__init__(optimizer)
self.module = module
self.optim_state = OptimState.UNSCALED
self.clipping_flag = clipping_norm > 0.0
self.max_norm = clipping_norm
self.region_manager: RegionManager = self.module.region_manager
self.param_to_range: Dict[torch.nn.Parameter, Tuple[int, int]] = dict()
self.param_to_region: Dict[torch.nn.Parameter, Region] = dict()
self.fp32_to_fp16_params: Dict[torch.Tensor, torch.nn.Parameter] = dict()
if self.clipping_flag:
assert norm_type == 2.0, "AMPOptimizer only supports L2 norm now"
self.__init__optimizer()
# Grad scaler
self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale)
self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
self._logger = get_dist_logger()
def _set_grad_ptr(self):
for group in self.param_groups:
for fake_param in group['params']:
region = self.param_to_region[fake_param]
begin, end = self.param_to_range[fake_param]
fake_param.data = region.cpu_grad[begin:end]
fake_param.grad = fake_param.data
fake_param.data = region.fp32_data[begin:end]
def _update_fp16_params(self):
none_tensor = torch.empty([0])
for group in self.param_groups:
for fake_param in group['params']:
assert fake_param.grad is None
fake_param.data = none_tensor
self.param_to_region[fake_param].cpu_grad = None
def _check_overflow(self):
# clear previous overflow record
self._found_overflow.fill_(self.module.overflow_counter.item())
return self._found_overflow.item() > 0
def _get_combined_scale(self):
loss_scale = 1
if self.optim_state == OptimState.SCALED:
loss_scale = self.loss_scale
self.optim_state = OptimState.UNSCALED
combined_scale = loss_scale
if combined_scale == 1:
return -1
else:
return combined_scale
@property
def loss_scale(self):
return self.grad_scaler.scale.item()
def zero_grad(self, *args, **kwargs):
self.module.overflow_counter = torch.cuda.IntTensor([0])
return self.optim.zero_grad(set_to_none=True)
def step(self, *args, **kwargs):
# Copy gradients from model params to main params.
self._set_grad_ptr()
found_inf = self._check_overflow()
if found_inf:
self.optim_state = OptimState.UNSCALED # no need to unscale grad
self.grad_scaler.update(found_inf) # update gradient scaler
self._logger.info(f'Found overflow. Skip step')
self.zero_grad() # reset all gradients
self._update_fp16_params()
return
# get combined scale. combined scale = loss scale * clipping norm
# so that gradient = gradient / combined scale
combined_scale = self._get_combined_scale()
self.grad_scaler.update(found_inf)
ret = self.optim.step(div_scale=combined_scale, *args, **kwargs)
self.zero_grad()
self._update_fp16_params()
return ret
def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):
raise NotImplementedError
def backward(self, loss: torch.Tensor):
loss = self.loss_scale * loss
self.optim_state = OptimState.SCALED
self.module.backward(loss)
def __init__optimizer(self):
for group in self.optim.param_groups:
fake_params_list = list()
for param in group['params']:
region = self.region_manager.get_region(param)
fake_param = torch.nn.Parameter(torch.empty([0]))
self.param_to_range[fake_param] = region.param_to_range[param]
self.param_to_region[fake_param] = region
fake_params_list.append(fake_param)
# Reset existing state dict key to the new main param.
if param in self.optim.state:
self.optim.state[fake_param] = self.optim.state.pop(param)
group['params'] = fake_params_list
# Leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors
self.optim.load_state_dict(self.optim.state_dict())

View File

@ -0,0 +1,109 @@
from typing import Optional, Set
from functools import partial
import torch
import torch.nn as nn
from colossalai.nn.parallel.data_parallel import _cast_float
from colossalai.gemini.tensor_utils import free_storage
from .region_manager import RegionManager
from .util import GlobalRuntimeInfo
class BaseOffloadModule:
"""
BaseOffloadModule: A model wrapper for parameter offloading.
Args:
model (nn.Module): model to apply offloading.
region_manager (RegionManager): a ``RegionManager`` instance.
is_sync (bool): synchronous mode or not.
"""
def __init__(self,
model: nn.Module,
region_manager: RegionManager,
is_sync=True):
self.model = model
self.region_manager = region_manager
self.grad_hook_list = []
self.overflow_counter = torch.cuda.IntTensor([0])
self.grad_offload_stream = torch.cuda.current_stream() if is_sync else GlobalRuntimeInfo.d2h_stream
self._cast_buffers()
def register_grad_hook(self):
for p in self.model.parameters():
if p.requires_grad:
self.grad_hook_list.append(p.register_hook(partial(self.grad_handle, p)))
def remove_grad_hook(self):
for hook in self.grad_hook_list:
hook.remove()
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
def _pre_forward(self):
self.register_grad_hook()
for region in self.region_manager.region_list:
region.cpu_grad = None
def forward(self, *args, **kwargs):
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
self.model.zero_grad(set_to_none=True)
self._pre_forward()
outputs = self.model(*args, **kwargs)
return outputs
def backward(self, loss):
loss.backward()
self._post_backward()
def _post_backward(self):
torch.cuda.synchronize()
self.remove_grad_hook()
for p in self.model.parameters():
p.grad = None
GlobalRuntimeInfo.fwd_prefetch_event_map.clear()
GlobalRuntimeInfo.bwd_prefetch_event_map.clear()
def grad_handle(self, p, grad):
empty_grad = torch.empty_like(grad)
free_storage(empty_grad)
with torch._C.DisableTorchFunction():
region = self.region_manager.get_region(p)
region.copy_grad_to_region_slice(p, grad)
if region.can_release:
self.overflow_counter += region.has_inf_or_nan
master_stream = torch.cuda.current_stream()
with torch.cuda.stream(self.grad_offload_stream):
GlobalRuntimeInfo.d2h_stream.wait_stream(master_stream)
region.move_grad_to_cpu()
return empty_grad
def _cast_buffers(self):
for buffer in self.model.buffers():
buffer.data = buffer.cuda()
def parameters(self, recurse: bool = True):
return self.model.parameters(recurse)
def named_parameters(self, prefix: str = '', recurse: bool = True):
return self.model.named_parameters(prefix, recurse)
def named_buffers(self, prefix: str = '', recurse: bool = True):
return self.model.named_buffers(prefix, recurse)
def named_children(self):
return self.model.named_children()
def named_modules(self,
memo: Optional[Set[torch.nn.Module]] = None,
prefix: str = '',
remove_duplicate: bool = True):
return self.model.named_modules(memo, prefix, remove_duplicate)

View File

@ -0,0 +1,49 @@
from typing import Dict
import torch
import torch.fx
from torch.fx import GraphModule
from torch.utils._pytree import tree_map
from colossalai.fx import ColoTracer, is_compatible_with_meta
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from .region_manager import RegionManager
from .runtime import runtime_syn_offload_apply_pass, runtime_asyn_offload_apply_pass
from .base_offload_module import BaseOffloadModule
from .util import compute_max_param_mem, compute_total_param_mem, compute_act_peak_mem, GlobalRuntimeInfo
def memory_optimize(model: torch.nn.Module,
inps: Dict[str, torch.Tensor],
memory_budget: float = -1.0,
solver_name: str = 'asyn'):
model = model.cpu().half()
tracer = ColoTracer()
assert is_compatible_with_meta()
wrap_fn = lambda x: x.to("meta") if isinstance(x, torch.Tensor) else x
meta_args = tree_map(wrap_fn, inps)
graph = tracer.trace(model, meta_args=meta_args)
gm = GraphModule(model, graph, model.__class__.__name__)
interp = MetaInfoProp(gm)
interp.propagate(*meta_args.values())
region_manager = RegionManager(graph, solver_name=solver_name, memory_budget=memory_budget)
region_manager._build_regions()
GlobalRuntimeInfo.region_list = region_manager.region_list
act_peak_mem = compute_act_peak_mem(region_manager.region_list) / 1024 ** 2
max_param_mem = compute_max_param_mem(region_manager.region_list) / 1024 ** 2
total_param_mem = compute_total_param_mem(region_manager.region_list) / 1024 ** 2
print(
f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}")
if solver_name == 'syn':
gm = runtime_syn_offload_apply_pass(gm, region_manager.region_list)
elif solver_name == 'asyn':
gm = runtime_asyn_offload_apply_pass(gm, region_manager.region_list)
else:
raise TypeError(f"Unknown solver name {solver_name}!")
gm.recompile()
optimized_model = BaseOffloadModule(gm, region_manager, solver_name=='syn')
return optimized_model

View File

@ -0,0 +1,144 @@
from typing import List, Dict, Tuple
import torch
from torch.fx import Node
from colossalai.gemini.tensor_utils import alloc_storage, free_storage
class Region:
"""
Region: A container owning a piece of contiguous nodes in the DNN computing graph.
Args:
r_id (int): the index of the region in the computing graph.
"""
def __init__(self, r_id: int = 0) -> None:
self.r_id: int = r_id
self.fp16_params: List[torch.nn.Parameter] = []
self.param_size: int = 0
self.shared_rid: int = self.r_id
self.param_num: int = 0
self.grad_num: int = 0
self.fp16_data = None
self.fp32_data = None
self.cpu_grad = None
self.temp_fp32_data = None
self.param_to_range: Dict[torch.nn.Parameter, Tuple[int, int]] = dict()
self.need_offload: bool = False
self.is_syn: bool = False
self.nodes: List[Node] = []
self.fwd_prefetch_region = None
self.bwd_prefetch_region = None
self.in_mem_pool_flag: bool = False
@property
def can_release(self) -> bool:
"""
Check if the region can be released.
"""
return self.grad_num == self.param_num
@property
def has_inf_or_nan(self) -> bool:
"""
Check if the grad of the region has inf or nan values on CUDA.
"""
return torch.isinf(self.fp16_data).any() | torch.isnan(self.fp16_data).any()
def init_param_data(self, pre_alloc_tensor: torch.Tensor = None):
"""
Map the parameters in the region to a contiguous memory space.
"""
self.fp16_data = torch.zeros(
self.param_num, dtype=torch.half, device='cuda')
offset = 0
for param in self.fp16_params:
param.data = param.data.cuda()
p_num = param.data.numel()
self.fp16_data[offset:offset + p_num].copy_(param.data.flatten())
param.data = self.fp16_data[offset:offset +
p_num].view(param.data.shape)
self.param_to_range[param] = (offset, offset + p_num)
offset += p_num
self.fp32_data = self.fp16_data.float().cpu().pin_memory()
free_storage(self.fp16_data)
if self.in_mem_pool_flag and pre_alloc_tensor is not None:
self.fp16_data = pre_alloc_tensor
def move_param_to_cuda(self):
"""
Move parameters from CPU to GPU.
It first moves float32 parameters to GPU and
then transforms float32 parameters to half-precision on the GPU.
The reason is that the performance of precision conversion on the CPU
is much slower than the data transfer overhead.
"""
self.temp_fp32_data.copy_(self.fp32_data, non_blocking=True)
self.temp_fp32_data.record_stream(torch.cuda.current_stream())
if not self.in_mem_pool_flag:
alloc_storage(self.fp16_data)
self.fp16_data[:self.param_num].copy_(self.temp_fp32_data)
self.fp16_data.record_stream(torch.cuda.current_stream())
self.__update_params_ptr()
def move_grad_to_cpu(self):
"""
Move gradients from GPU to CPU.
"""
self.cpu_grad = torch.empty(self.param_num, dtype=torch.half, pin_memory=True)
self.cpu_grad.copy_(self.fp16_data[:self.param_num], non_blocking=True)
self.fp16_data.record_stream(torch.cuda.current_stream())
if not self.in_mem_pool_flag:
self.free_cuda_data()
self.grad_num = 0
def free_cuda_data(self):
free_storage(self.fp16_data)
# torch.cuda.empty_cache()
def copy_grad_to_region_slice(self, param: torch.nn.Parameter, data_slice: torch.Tensor) -> None:
"""
Copy data slice to the memory space indexed by the input tensor in the region.
Args:
param (torch.nn.Parameter): the param used to retrive meta information
data_slice (torch.Tensor): the tensor to be copied to the region
"""
begin, end = self.param_to_range[param]
self.fp16_data[begin:end].copy_(data_slice.data.flatten())
param.data = self.fp16_data[begin:end].view(param.data.shape)
self.grad_num += data_slice.numel()
def split(self, cut_node_idx: int, cut_param_idx: int):
"""
Split the region into two and return the latter.
"""
new_reg = Region(r_id=self.r_id + 1)
new_reg.nodes = self.nodes[cut_node_idx:]
new_reg.fp16_params = self.fp16_params[cut_param_idx:]
for p in new_reg.fp16_params:
new_reg.param_size += p.data.numel() * p.data.element_size()
new_reg.param_num += p.data.numel()
self.nodes = self.nodes[:cut_node_idx]
self.fp16_params = self.fp16_params[:cut_param_idx]
self.param_size -= new_reg.param_size
self.param_num -= new_reg.param_num
return new_reg
def __update_params_ptr(self) -> None:
for param in self.fp16_params:
begin, end = self.param_to_range[param]
param.data = self.fp16_data[begin:end].view(param.data.shape)

View File

@ -0,0 +1,526 @@
from typing import List, Any, Dict, Tuple
import torch
from torch.fx import Graph, Node
from .solver import SolverFactory
from .training_simulator import TrainingSimulator
from .region import Region
from .util import NodeInfo
class RegionManager:
"""
RegionManager is used to construct and manage the offload plan for the model execution.
Args:
graph (Graph): a Graph object used for analysis and strategy generation.
solver_name (str): a solver name which specifies the preferences for plan searching.
memory_budget (float): the given memory budget.
cnode (List[str], optional): Common node List, should be the subset of input.
"""
def __init__(self,
graph: Graph,
solver_name: str = 'asyn',
memory_budget: float = -1.0,
cnode: List[str] = None):
self.graph = graph
assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
self.root_module = self.graph.owning_module
self.nodes = list(graph.nodes)
self.cnode = cnode
self.only_param_ops = []
self.param_region_map: Dict[torch.nn.Parameter, Region] = dict()
self.shared_region_pairs: List[Tuple[Region, Region]] = list()
self.region_list: List[Region] = list()
self.rid_in_pool: List[int] = list()
self.mem_block_size: int = 0
self.memory_budget = memory_budget
self.solver_name = solver_name
self.require_pool: bool = solver_name == 'asyn'
self.reg_to_block: Dict[int, int] = dict()
def _build_regions(self):
"""
1. Pre-processing, mainly contains linearized computing graph and
merge smaller regions into larger ones.
2. Construct a solver to search for an efficient offload strategy.
3. Post-processing, mainly contains early region placement if using asynchronous mode,
and initialize region data.
"""
self._pre_process()
solver_cls = SolverFactory.create(self.solver_name)
solver = solver_cls(self.region_list, self.memory_budget)
solver._call_solver()
self._post_process(solver.best_ts)
def _pre_process(self):
init_region_list = self._linearize_graph()
if len(self.shared_region_pairs) > 1:
raise NotImplementedError(
'The current version only considers at most one pair of parameter sharing.')
elif len(self.shared_region_pairs) == 1:
shared_regs = self.shared_region_pairs[0]
assert shared_regs[0].shared_rid == shared_regs[1].r_id \
and shared_regs[1].shared_rid == shared_regs[0].r_id
fst_id = shared_regs[0].r_id
lst_id = shared_regs[1].r_id
regs_left_out = init_region_list[:fst_id + 1]
regs_right_out = init_region_list[lst_id:]
hold_regs = init_region_list[fst_id + 1:lst_id]
else:
regs_left_out = []
regs_right_out = []
hold_regs = init_region_list
self.mem_block_size = self._search_block_size(hold_regs)
hold_regs = self._merge_small_regions(hold_regs)
if self.require_pool:
for reg in hold_regs:
reg.in_mem_pool_flag = True
self.rid_in_pool.append(reg.r_id)
self.region_list.extend(regs_left_out)
self.region_list.extend(hold_regs)
for reg in regs_right_out:
reg.r_id = self.region_list[-1].r_id + 1
self.region_list[reg.shared_rid].shared_rid = reg.r_id
self.region_list.append(reg)
self._process_shared_region()
self.max_param_num = max([reg.param_num for reg in self.region_list])
self.memory_budget -= self.max_param_num * torch.tensor([], dtype=torch.float32).element_size()
def _post_process(self, ts: TrainingSimulator = None):
if self.require_pool:
self._early_region_placement(ts)
self._init_region_data()
def _early_region_placement(self, ts: TrainingSimulator):
"""
Implemented the early region placement strategy to avoid GPU memory fragmentation.
It maps all region data into a contiguous memory space and
reuses the same memory space for regions that do not coexist.
Args:
ts (TrainingSimulator): the best training simulator, which records region execution flow.
Raises:
NotImplementedError: due to the naive implementation,
it may not find a suitable region placement strategy for the given execution flow.
"""
reg_flow = torch.cat(
[ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0)
mem_block_num = torch.max(
torch.sum(reg_flow[:, self.rid_in_pool], dim=1))
coexist_matrix = torch.logical_or(
ts.fwd_reg_flow, ts.bwd_reg_flow)
block_to_regs = {}
for block_idx in range(mem_block_num):
block_to_regs[block_idx] = []
for reg in self.region_list:
if reg.r_id in self.rid_in_pool:
cur_reg_appears = coexist_matrix[:, reg.r_id]
cur_reg_coexists = torch.sum(
coexist_matrix[cur_reg_appears], dim=0).bool()
for block_idx in range(mem_block_num):
if not any(cur_reg_coexists[block_to_regs[block_idx]]):
block_to_regs[block_idx].append(reg.r_id)
self.reg_to_block[reg.r_id] = block_idx
break
if reg.r_id not in self.reg_to_block:
raise NotImplementedError(
f'can not find a block from the memory pool to store parameters of the region')
self.memory_pool = torch.chunk(torch.zeros(int(
mem_block_num * self.mem_block_size / 2), dtype=torch.half, device='cuda'), chunks=int(mem_block_num))
def _merge_small_regions(self, orig_reg_list: List[Region]) -> List[Region]:
"""
Merge smaller regions into larger ones for better bandwidth utilization and easier management.
It is inspired by Gemini.
Args:
orig_reg_list (List[Region]): original region list.
Returns:
List[Region]: region list after merging.
"""
r_id = orig_reg_list[0].r_id
region = Region(r_id=r_id)
region_list = [region]
for orig_reg in orig_reg_list:
if region_list[-1].param_size + orig_reg.param_size > self.mem_block_size:
r_id += 1
region = Region(r_id=r_id)
region_list.append(region)
region.param_size += orig_reg.param_size
region.param_num += orig_reg.param_num
region.nodes.extend(orig_reg.nodes)
region.fp16_params.extend(orig_reg.fp16_params)
self.__update_param_region_map(orig_reg.fp16_params, region)
return region_list
def _search_block_size(self,
region_list: List[Region],
search_interval_byte: int = 1024,
search_range_byte: int = 128 * 1024 ** 2) -> int:
"""
Search for a suitable memory block size.
Args:
region_list (List[Region]): region list.
search_interval_byte (int): searching interval in byte.
search_range_byte (int): searching range in byte.
Returns:
int: the best memory block size.
"""
def _get_wasted_mem(size_list: List[int], blk_size: int):
"""
Get wasted byte for a certain block size.
"""
acc_wasted = 0
left = 0
for s in size_list:
if left + s > blk_size:
acc_wasted += blk_size - left
left = s
left += s
acc_wasted += blk_size - left
return acc_wasted
param_size_list = [
region.param_size for region in region_list if region.r_id == region.shared_rid]
start_size = max(param_size_list)
min_mem_waste = float('+inf')
best_block_size = start_size
for block_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
temp_waste = 0
temp_waste += _get_wasted_mem(param_size_list, block_size)
if temp_waste < min_mem_waste:
min_mem_waste = temp_waste
best_block_size = block_size
return best_block_size
def _init_region_data(self):
"""
Initialize region data, which maps the parameters in the region to a contiguous memory space.
"""
self.temp_fp32_data = torch.zeros(self.max_param_num, device='cuda', dtype=torch.float32)
for region in self.region_list:
pre_alloc_tensor = None
if self.require_pool and region.r_id in self.rid_in_pool:
block_idx = self.reg_to_block[region.r_id]
pre_alloc_tensor = self.memory_pool[block_idx]
if region.r_id <= region.shared_rid:
region.init_param_data(pre_alloc_tensor)
else:
shared_region = self.region_list[region.shared_rid]
region.fp16_data = shared_region.fp16_data
region.fp32_data = shared_region.fp32_data
region.param_to_range = shared_region.param_to_range
region.temp_fp32_data = self.temp_fp32_data[:region.param_num].detach(
)
torch.cuda.empty_cache()
def _process_shared_region(self):
"""
Special processing for the shared region, which uses GPT2 and Bert case as a priori knowledge.
"""
if len(self.shared_region_pairs):
assert len(self.shared_region_pairs) <= 1
former_reg, latter_reg = self.shared_region_pairs[0]
assert latter_reg.param_num >= former_reg.param_num
embedding_node = former_reg.nodes[-1]
assert embedding_node.op == 'call_module' and isinstance(
self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding)
if latter_reg.param_num > former_reg.param_num:
for idx, n in enumerate(latter_reg.nodes):
if (n.op == 'call_module' and isinstance(self.root_module.get_submodule(n.target),
torch.nn.Linear)) or \
(n.op == 'call_function' and n.target is torch.nn.functional.linear):
cut_node_idx = idx + 1
break
assert len(latter_reg.fp16_params) == 2
new_reg = latter_reg.split(cut_node_idx, 1)
for p in new_reg.fp16_params:
self.param_region_map[p] = new_reg
self.region_list.insert(new_reg.r_id, new_reg)
for reg in self.region_list[new_reg.r_id + 1:]:
reg.r_id += 1
latter_reg.shared_rid = former_reg.r_id
former_reg.shared_rid = latter_reg.r_id
def _linearize_graph(self) -> List[Region]:
"""Linearizing the graph
Args:
graph (Graph): The computing graph to be optimized.
Returns:
List[Region]: each region contains the actual 'node' in linearized manner.
Remarks:
Do merge the inplace ops and shape-consistency ops into the previous node.
"""
# List of target name that could be seen as common node
common_ops = ["getattr", "getitem", "size"]
def _is_cop(target: Any) -> bool:
"""Check if an op could be seen as common node
Args:
target (Any): node target
Returns:
bool
"""
if isinstance(target, str):
return target in common_ops
else:
return target.__name__ in common_ops
def _is_act(data: Any) -> bool:
"""Check if an op could be seen as parameter computation start
Args:
data (Any): meta_data
Returns:
bool
"""
label = False
if isinstance(data, torch.Tensor):
return True
elif isinstance(data, (tuple, list)):
for d in data:
label = label or _is_act(d)
return label
def _maybe_param_comp_start() -> bool:
"""Check if an op could be seen as parameter computation start
Args:
n (Node): node
Returns:
bool
"""
label = False
if n.op == "get_attr":
label = True
elif n.op == "call_module":
target = n.target
submod = self.root_module.get_submodule(target)
if (
len(list(submod.named_parameters(recurse=False))) != 0
or len(list(submod.named_buffers(recurse=False))) != 0
):
label = True
return label and not sum([v for _, v in param_op_deps.items()])
def _is_param_comp_end() -> bool:
"""Check if an op could be seen as parameter computation end
Args:
n (Node): node
Returns:
bool
"""
def _is_inplace(n: Node):
"""Get the inplace argument from ``torch.fx.Node``
"""
inplace = False
if n.op == "call_function":
inplace = n.kwargs.get("inplace", False)
elif n.op == "call_module":
inplace = getattr(n.graph.owning_module.get_submodule(
n.target), "inplace", False)
return inplace
label = False
if n.op == "call_module":
target = n.target
submod = self.root_module.get_submodule(target)
if (
len(list(submod.named_parameters(recurse=False))) != 0
or len(list(submod.named_buffers(recurse=False))) != 0
):
label = True
elif n.op == "call_function":
label = any(map(lambda x: x.name in self.only_param_ops, n.all_input_nodes)) and any(
map(lambda x: x.name not in self.only_param_ops and not _is_cop(n.target), n.all_input_nodes))
return label and not sum([v for _, v in param_op_deps.items()]) and not any(map(_is_inplace, n.users))
def _exception_node_handling():
# TODO meta info prop bug
if n.name.__contains__("transpose") and n.meta['fwd_out'][0].dim() <= 2:
n.meta['fwd_out'] = []
# make sure that item in cnode is valid
if self.cnode:
for name in self.cnode:
try:
assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \
f"Common node {name} is not an input of the model."
except StopIteration:
raise ValueError(f"Common node name {name} not in graph.")
else:
self.cnode = []
node_id = 0
region_id = 0
param_op_deps = {}
deps = {}
region_list = []
region = Region(r_id=region_id)
act_n = None
for n in self.graph.nodes:
if n.op != "placeholder" and n.op != "output":
for n_par in n.all_input_nodes:
if n_par.op != "placeholder" and n_par.name not in self.cnode:
deps[n_par] -= 1
if n_par.op != "placeholder" and n_par.name in self.only_param_ops:
param_op_deps[n_par] -= 1
if act_n in region.nodes and _maybe_param_comp_start():
ns = []
border_n_idx = region.nodes.index(act_n)
if border_n_idx < len(region.nodes):
ns = region.nodes[border_n_idx + 1:]
region.nodes = region.nodes[:border_n_idx + 1]
region_list.append(region)
region_id += 1
region = Region(r_id=region_id)
region.nodes = ns
_exception_node_handling()
region.nodes.append(n)
self._set_node_and_region_info(node_id, n, region)
node_id += 1
# if the node could free all dependencies in graph
# we could begin a new region
if _is_param_comp_end():
region_list.append(region)
region_id += 1
region = Region(r_id=region_id)
# propagate common node attr if possible
if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode
]) or _is_cop(n.target):
self.cnode.append(n.name)
else:
deps[n] = len(
[user for user in n.users if user.op != "output"])
# propagate param node attr if possible
if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.only_param_ops
]) or n.op == "get_attr":
self.only_param_ops.append(n.name)
param_op_deps[n] = len(
[user for user in n.users if user.op != "output"])
# record last activation node
if _is_act(n._meta_data):
act_n = n
if len(region.nodes):
region_list.append(region)
return region_list
def _set_node_and_region_info(self, node_id: int, cur_n: Node, cur_reg: Region):
cur_n.node_info = NodeInfo(node_id)
if cur_n.op == 'call_module':
target = cur_n.target
submod = self.root_module.get_submodule(target)
for p in list(submod.parameters(recurse=False)):
if p in self.param_region_map:
cur_reg.shared_rid = self.param_region_map[p].r_id
self.param_region_map[p].shared_rid = cur_reg.r_id
self.shared_region_pairs.append(
(self.param_region_map[p], cur_reg))
else:
self.param_region_map[p] = cur_reg
cur_reg.fp16_params.append(p)
cur_reg.param_num += p.data.numel()
cur_reg.param_size += p.data.numel() * p.data.element_size()
elif cur_n.op == "get_attr":
attr_itr = self.root_module
atoms = cur_n.target.split(".")
for atom in atoms:
attr_itr = getattr(attr_itr, atom)
if isinstance(attr_itr, torch.nn.Parameter):
if attr_itr in self.param_region_map:
cur_reg.shared_rid = self.param_region_map[attr_itr].r_id
self.param_region_map[attr_itr].shared_rid = cur_reg.r_id
self.shared_region_pairs.append(
(self.param_region_map[attr_itr], cur_reg))
else:
self.param_region_map[attr_itr] = cur_reg
cur_reg.fp16_params.append(attr_itr)
cur_reg.param_num += attr_itr.data.numel()
cur_reg.param_size += attr_itr.data.numel() * attr_itr.data.element_size()
def get_region(self, param: torch.nn.Parameter) -> Region:
"""
Return the region owning the parameter.
Args:
param (torch.nn.Parameter): a torch parameter object
"""
return self.param_region_map[param]
def __update_param_region_map(self, params: List[torch.nn.Parameter], region: Region):
for p in params:
self.param_region_map[p] = region

View File

@ -0,0 +1,253 @@
from typing import List
import torch
from torch.fx.node import Node
from .region import Region
from .util import GlobalRuntimeInfo, requires_upload_p_in_fwd
class SynPreFwdPostBwdOP(torch.autograd.Function):
"""
A customized prefetch and offload operation.
Args:
input_: input tensor.
fwd_info: information dict, which contains region indices
that need to be uploaded or freed during forward pass.
bwd_info: information dict, which contains region indices
that need to be uploaded during backward pass.
"""
@staticmethod
def forward(ctx, input_, fwd_info, bwd_info):
ctx.bwd_info = bwd_info
d2h_rid = fwd_info.get('d2h_rid', None)
if d2h_rid is not None:
free_region = GlobalRuntimeInfo.region_list[d2h_rid]
assert isinstance(free_region, Region)
free_region.free_cuda_data()
h2d_rid = fwd_info.get('h2d_rid', None)
if h2d_rid is not None:
h2d_region = GlobalRuntimeInfo.region_list[h2d_rid]
assert isinstance(h2d_region, Region)
h2d_region.move_param_to_cuda()
return input_
@staticmethod
def backward(ctx, grad_output):
h2d_rid = ctx.bwd_info.get('h2d_rid', None)
if h2d_rid is not None:
pref_region = GlobalRuntimeInfo.region_list[h2d_rid]
assert isinstance(pref_region, Region)
pref_region.move_param_to_cuda()
return grad_output, None, None
class AsynPreFwdPostBwdOP(torch.autograd.Function):
"""
A customized prefetch and offload operation.
Args:
input_: input tensor.
fwd_info: information dict, which contains region indices
that need to be prefetched, waited, or freed during forward pass.
bwd_info: information dict, which contains region indices
that need to be prefetched or waited during backward pass.
"""
@staticmethod
def forward(ctx, input_, fwd_info, bwd_info):
ctx.bwd_info = bwd_info
sync_rid = fwd_info.get('sync_rid', None)
if sync_rid is not None:
prefetch_event = GlobalRuntimeInfo.fwd_prefetch_event_map.get(
sync_rid, None)
if prefetch_event:
prefetch_event.wait()
h2d_rid = fwd_info.get('h2d_rid', None)
if h2d_rid is not None:
pref_region = GlobalRuntimeInfo.region_list[h2d_rid]
assert isinstance(pref_region, Region)
master_stream = torch.cuda.current_stream()
with torch.cuda.stream(GlobalRuntimeInfo.h2d_stream):
GlobalRuntimeInfo.h2d_stream.wait_stream(master_stream)
pref_region.move_param_to_cuda()
prefetch_event = torch.cuda.Event()
prefetch_event.record(GlobalRuntimeInfo.h2d_stream)
GlobalRuntimeInfo.fwd_prefetch_event_map[h2d_rid] = prefetch_event
return input_
@staticmethod
def backward(ctx, grad_output):
sync_rid = ctx.bwd_info.get('sync_rid', None)
if sync_rid is not None:
wait_region = GlobalRuntimeInfo.region_list[sync_rid]
assert isinstance(wait_region, Region)
prefetch_event = GlobalRuntimeInfo.bwd_prefetch_event_map.get(
sync_rid, None)
if prefetch_event:
prefetch_event.wait()
else:
wait_region.move_param_to_cuda()
h2d_rid = ctx.bwd_info.get('h2d_rid', None)
if h2d_rid is not None:
pref_region = GlobalRuntimeInfo.region_list[h2d_rid]
assert isinstance(pref_region, Region)
master_stream = torch.cuda.current_stream()
with torch.cuda.stream(GlobalRuntimeInfo.h2d_stream):
GlobalRuntimeInfo.h2d_stream.wait_stream(master_stream)
pref_region.move_param_to_cuda()
prefetch_event = torch.cuda.Event()
prefetch_event.record(GlobalRuntimeInfo.h2d_stream)
GlobalRuntimeInfo.bwd_prefetch_event_map[h2d_rid] = prefetch_event
return grad_output, None, None
def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info):
'''
Convert Upload and Offload operation into runtime action.
Argument:
tensor(torch.Tensor): input tensor.
fwd_info(dict): information dict, which contains region indices
that need to be uploaded, or freed during forward pass.
bwd_info(dict): information dict, which contains region indices
that need to be uploaded during backward pass.
'''
with torch._C.DisableTorchFunction():
ret = SynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
return ret
def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info):
'''
Convert Prefetch and Offload operation into runtime action.
Argument:
tensor(torch.Tensor): input tensor.
fwd_info(dict): information dict, which contains region indices
that need to be prefetched, waited, or freed during forward pass.
bwd_info(dict): information dict, which contains region indices
that need to be prefetched or waited during backward pass.
'''
with torch._C.DisableTorchFunction():
ret = AsynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
return ret
def replace_node_users(orig_node: Node, inserted_node: Node, rep_user_nodes: List[Node] = None):
user_list = list(orig_node.users.keys())
if rep_user_nodes is not None:
user_list = rep_user_nodes
for user in user_list:
if user == inserted_node:
continue
new_args = list(user.args)
new_kwargs = dict(user.kwargs)
# the origin node may be a positional argument or key word argument of user node
if orig_node in new_args:
# substitute the origin node with offload_apply_node
new_args[new_args.index(orig_node)] = inserted_node
user.args = tuple(new_args)
elif str(orig_node) in new_kwargs:
# substitute the origin node with offload_apply_node
new_kwargs[str(orig_node)] = inserted_node
user.kwargs = new_kwargs
def runtime_syn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[Region]):
"""
This pass is used to add the synchronous upload and offload spec apply node to the origin graph.
"""
mod_graph = gm.graph
last_inp_node = tuple(mod_graph.nodes)[0]
for r_idx, region in enumerate(region_list):
# forward upload
fwd_info = {}
if requires_upload_p_in_fwd(region_list[region.shared_rid]):
fwd_info['h2d_rid'] = region.r_id
# forward offload
if r_idx > 0 and region_list[r_idx - 1].need_offload:
fwd_info['d2h_rid'] = r_idx - 1
bwd_info = {}
# backward upload
if r_idx > 0 and region_list[r_idx - 1].need_offload:
bwd_info['h2d_rid'] = region_list[r_idx - 1].r_id
if fwd_info or bwd_info:
with mod_graph.inserting_after(last_inp_node):
new_node = mod_graph.create_node('call_function', convert_fwd_upload_bwd_offload_to_action,
args=(last_inp_node, fwd_info, bwd_info))
replace_node_users(last_inp_node, new_node)
last_inp_node = region.nodes[-1]
return gm
def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[Region]):
"""
This pass is used to add the asynchronous prefetch and offload spec apply node to the origin graph.
"""
mod_graph = gm.graph
# upload parameters of the first region
last_inp_node = tuple(mod_graph.nodes)[0]
first_region_with_p = [
region for region in region_list if region.param_size][0]
fwd_info = {"h2d_rid": first_region_with_p.r_id}
with mod_graph.inserting_after(last_inp_node):
upload_apply_node = mod_graph.create_node('call_function', convert_fwd_upload_bwd_offload_to_action,
args=(last_inp_node, fwd_info, {}))
replace_node_users(last_inp_node, upload_apply_node)
last_inp_node = upload_apply_node
for r_idx, region in enumerate(region_list):
# forward prefetch
fwd_info = {}
if region.param_size:
fwd_info['sync_rid'] = region.r_id
fwd_prefetch_region = region.fwd_prefetch_region
if fwd_prefetch_region and requires_upload_p_in_fwd(region_list[fwd_prefetch_region.shared_rid]):
fwd_info['h2d_rid'] = fwd_prefetch_region.r_id
# forward offload
if r_idx > 0 and region_list[r_idx-1].need_offload:
fwd_info['d2h_rid'] = r_idx - 1
bwd_info = {}
# backward prefetch
if r_idx > 0 and region_list[r_idx-1].need_offload:
bwd_info['sync_rid'] = r_idx - 1
if r_idx > 0 and region_list[r_idx-1].bwd_prefetch_region:
bwd_info['h2d_rid'] = region_list[r_idx-1].bwd_prefetch_region.r_id
if fwd_info or bwd_info:
with mod_graph.inserting_after(last_inp_node):
new_node = mod_graph.create_node('call_function', convert_fwd_prefetch_bwd_offload_to_action,
args=(last_inp_node, fwd_info, bwd_info))
replace_node_users(last_inp_node, new_node)
last_inp_node = region.nodes[-1]
if region.bwd_prefetch_region:
bwd_info = {'h2d_rid': region.bwd_prefetch_region.r_id}
with mod_graph.inserting_after(last_inp_node):
new_node = mod_graph.create_node('call_function', convert_fwd_prefetch_bwd_offload_to_action,
args=(last_inp_node, {}, bwd_info))
replace_node_users(last_inp_node, new_node)
# gm.graph.print_tabular()
return gm

View File

@ -0,0 +1,523 @@
import time
from typing import List, Dict, Type
from abc import ABC, abstractmethod
NOT_NVML = False
try:
from pynvml import *
except:
NOT_NVML = True
import torch
from torch.fx.node import Node
from colossalai.utils.cuda import get_current_device
from .training_simulator import TrainingSimulator, SynTrainingSimulator, AsynTrainingSimulator
from .region import Region
from .util import NodeInfo, NvDevicePower
def benchmark_func(func, number=1, repeat=1, warmup=3):
"""
benchmark data transfer cost.
"""
for i in range(warmup):
func()
costs = []
for i in range(repeat):
torch.cuda.synchronize()
begin = time.time()
for i in range(number):
func()
torch.cuda.synchronize()
costs.append((time.time() - begin) / number)
return sum(costs) / len(costs)
class Solver(ABC):
"""
The parameter offload solver.
Args:
region_list (List[Region]): represents the linearized DNN computing graph.
memory_budget (float): the given memory budget.
error_factor (float): the error factor.
It is used to reduce the memory budget. Due to some errors in the estimation of peak memory and execution time.
"""
def __init__(self,
region_list: List[Region],
memory_budget: float = -1.0,
error_factor: float = 0.95) -> None:
self.region_list = region_list
self.error_factor: float = error_factor
if memory_budget > 0:
self.memory_budget = memory_budget * self.error_factor
else:
self.memory_budget = torch.cuda.get_device_properties(
get_current_device()).total_memory * self.error_factor
self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth()
self.comp_power: float = self._extract_computing_power()
@abstractmethod
def _call_solver(self):
raise NotImplementedError
@abstractmethod
def _try_to_offload(self, *args):
raise NotImplementedError
@abstractmethod
def _eval_one_choice(self, *args):
raise NotImplementedError
def _compute_offload_profit(self, total_mem_saving: float, peak_mem_saving: float, extra_cost: float):
"""
Compute the profits of the offload strategies,
which packages the memory savings information for subsequent comparisons.
Args:
total_mem_saving (float): the total memory saving of the offload strategy.
peak_mem_saving (float): the peak memory saving of the offload strategy.
extra_cost (float): extra data transfer cost.
Returns:
tuple: profit information, the first term represents memory savings per unit of time.
"""
if extra_cost == 0:
# means data transfer overhead can be completely overlapped
return (float('inf'), total_mem_saving, peak_mem_saving)
return (total_mem_saving / extra_cost, total_mem_saving, peak_mem_saving)
def _compare_profit(self, profit_a: tuple, profit_b: tuple) -> bool:
"""
Compare the profits of the two offload strategies using the dictionary order algorithm.
Args:
profit_a (tuple): the profit of a offload strategy.
profit_b (tuple): the profit of another offload strategy.
Returns:
bool: whether profit_a is greater than profit_b.
"""
for val1, val2 in zip(profit_a, profit_b):
if val1 != val2:
return val1 > val2
return False
def _update_state(self, best_ts: TrainingSimulator):
"""
Update the solver state.
"""
self.best_ts = best_ts
self._update_node_mem_info(best_ts.fwd_node_mem, best_ts.bwd_node_mem)
def _update_node_mem_info(self,
fwd_mem_info: Dict[Node, float],
bwd_mem_info: Dict[Node, float]):
"""
Update the runtime memory information of the node.
Args:
fwd_mem_info (Dict[Node, float]): the runtime memory of each node in forward pass.
bwd_mem_info (Dict[Node, float]): the runtime memory of each node in backward pass.
"""
for node, mem in fwd_mem_info.items():
assert hasattr(node, 'node_info') and isinstance(
node.node_info, NodeInfo)
node.node_info.runtime_fwd_mem = mem
for node, mem in bwd_mem_info.items():
assert hasattr(node, 'node_info') and isinstance(
node.node_info, NodeInfo)
node.node_info.runtime_bwd_mem = mem
def _extract_computing_power(self):
"""
return the FP16 computing performance of the current NVIDIA GPU.
Raises:
TypeError: Unknown NVIDIA GPU device.
"""
nvmlInit()
handle = nvmlDeviceGetHandleByIndex(0)
device_name = nvmlDeviceGetName(handle)
units = 1e12
if device_name.__contains__("RTX 3080"):
return NvDevicePower.RTX3080_FP16 * units
elif device_name.__contains__("RTX 3090"):
return NvDevicePower.RTX3090_FP16 * units
elif device_name.__contains__('V100'):
return NvDevicePower.V100_FP16 * units
elif device_name.__contains__("A100"):
return NvDevicePower.A100_FP16 * units
else:
raise TypeError(f'Unknown NVIDIA GPU device name {device_name}')
def _profile_bandwidth(self):
"""
Profile the bidirectional communication bandwidth between CPU and GPU
using data volumes ranging from 1KB to 1GB.
"""
print('profiling bandwidth ......')
link_to_bandwidth = {}
links = ['h2d', 'd2h']
for link in links:
t_size = 1024
size_to_bandwidth = {}
# from 1KB to 1GB
for i in range(21):
if link == 'h2d':
src_tensor = torch.ones(
int(t_size), dtype=torch.int8, pin_memory=True)
dst_tensor = torch.ones(
(int(t_size)), dtype=torch.int8, device='cuda')
elif link == 'd2h':
src_tensor = torch.ones(
int(t_size), dtype=torch.int8, device='cuda')
dst_tensor = torch.ones(
(int(t_size)), dtype=torch.int8, pin_memory=True)
def func():
dst_tensor.copy_(src_tensor)
size_to_bandwidth[t_size] = t_size / benchmark_func(func, number=5, repeat=3)
print(f'size: {t_size / 1024 ** 2:.3f} MB, '
f'{src_tensor.device.type}-to-{dst_tensor.device.type} '
f'bandwidth: {size_to_bandwidth[t_size] / 1024 ** 3:.3f} GB/s')
t_size *= 2
link_to_bandwidth[link] = size_to_bandwidth
return link_to_bandwidth
class SynGreedySolver(Solver):
def __init__(self,
region_list: List[Region],
memory_budget: float = -1.0) -> None:
super().__init__(region_list, memory_budget)
self.best_ts: SynTrainingSimulator = None
self._init_state()
def _init_state(self):
"""
Initialize the solver state when without offloading.
"""
ts = SynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
ts.execute()
self._update_state(ts)
def _call_solver(self):
"""
Call the solver to search an efficient parameter offloading strategy for the linearized graph.
The solver adopts greedy algorithm.
Raises:
NotImplementedError: Unable to find a solution for the given memory budget.
"""
print("search offloading strategy ......")
while self.best_ts.peak_mem > self.memory_budget:
offload_region = None
best_ts = None
max_profit = (0,)
# search which region should be offloaded,
# the last region does not need to be offloaded.
for region in self.region_list[:-1]:
if region.param_size and not region.need_offload:
temp_ts, profit = self._try_to_offload(region)
if self._compare_profit(profit, max_profit):
offload_region = region
max_profit = profit
best_ts = temp_ts
if offload_region is not None and best_ts is not None:
offload_region.need_offload = True
offload_region.is_syn = True
self._update_state(best_ts)
else:
raise NotImplementedError(
f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, "
f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!")
def _call_solver_l2l(self):
"""
The layer-wise offload strategy.
"""
for region in self.region_list[:-1]:
region.need_offload = True
region.is_syn = True
def _try_to_offload(self, offload_region: Region):
# record previous information
orig_need_offload = offload_region.need_offload
assert not orig_need_offload
offload_region.need_offload = True
ts, profit = self._eval_one_choice(offload_region)
# restore previous information
offload_region.need_offload = orig_need_offload
return ts, profit
def _eval_one_choice(self, offload_region: Region):
"""
Evaluate the profit of a strategy choice.
Args:
offload_region (Region): the offload region of current choice.
Returns:
SynTrainingSimulator: the training simulator corresponding to the current strategy.
tuple: contains memory saving and cost information of the current strategy.
"""
ts = SynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
ts.execute()
extra_comm_cost = 2.0 * \
ts._get_communication_overhead('h2d', offload_region.param_size)
# the shared region needs to be moved twice
if offload_region.r_id < offload_region.shared_rid:
extra_comm_cost *= 2.0
profit = self._compute_offload_profit(
ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
return ts, profit
class AsynGreedySolver(Solver):
def __init__(self,
region_list: List[Region],
memory_budget: float = -1.0,
search_window_size: int = 3):
super().__init__(region_list, memory_budget)
self.search_window_size = search_window_size
# Records the prefetch execution location of the offloaded region
self.region_to_region_map = {}
self.best_ts: AsynTrainingSimulator = None
self._init_state()
def _init_state(self):
"""
Initialize the solver state when without offloading.
"""
ts = AsynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
ts.execute()
self._update_state(ts)
print("init peak memory", self.best_ts.peak_mem / 1024 ** 2, "MB")
def _call_solver(self):
"""
Call the solver to search an efficient parameter offloading strategy for the linearized graph.
The solver adopts greedy algorithm.
Raises:
NotImplementedError: Unable to find a solution for the given memory budget.
"""
print("search for offloading strategy ......")
# Records the prefetch execution location of the offloaded region
region_to_region_map = {}
while self.best_ts.peak_mem > self.memory_budget:
region_to_offload = None
max_offload_profit = (0,)
best_offl_ts = None
# search which region should be offloaded,
# the last region does not need to be offloaded
for region in self.region_list[:-1]:
if region.param_size and not region.need_offload:
max_prefetch_profit = (0,)
best_pref_ts = None
# search when to prefetch the region offloaded
for host_region in self.region_list[region.r_id + 1:region.r_id + 1 + self.search_window_size]:
if host_region.bwd_prefetch_region is not None:
continue
temp_ts, profit = self._try_to_offload(
host_region, region)
if self._compare_profit(profit, max_prefetch_profit):
region_to_region_map[region.r_id] = host_region
max_prefetch_profit = profit
best_pref_ts = temp_ts
if profit[0] == float('inf'):
break
if self._compare_profit(max_prefetch_profit, max_offload_profit):
region_to_offload = region
max_offload_profit = max_prefetch_profit
best_offl_ts = best_pref_ts
if (region_to_offload is not None) and (best_offl_ts is not None):
region_to_offload.need_offload = True
if region_to_region_map[region_to_offload.r_id] == region_to_offload:
region_to_offload.is_syn = True
else:
region_to_region_map[region_to_offload.r_id].bwd_prefetch_region = region_to_offload
self.region_to_region_map[region_to_offload.r_id] = region_to_region_map[region_to_offload.r_id]
self._update_state(best_offl_ts)
elif self.region_to_region_map.__len__() > 0:
self._repair_strategy()
else:
raise NotImplementedError(
f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, "
f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!")
region_to_region_map.clear()
def _try_to_offload(self, host_region: Region, offload_region: Region):
"""
Attempts to offload the region and prefetch it in backward pass.
"""
# record previous information
orig_prefetch = host_region.bwd_prefetch_region
orig_is_syn = offload_region.is_syn
orig_need_offload = offload_region.need_offload
if host_region == offload_region:
offload_region.is_syn = True
else:
host_region.bwd_prefetch_region = offload_region
offload_region.need_offload = True
ts, profit = self._eval_one_choice()
# restore previous information
host_region.bwd_prefetch_region = orig_prefetch
offload_region.is_syn = orig_is_syn
offload_region.need_offload = orig_need_offload
return ts, profit
def _try_convert_to_syn_upload(self, host_region: Region, offload_region: Region):
"""
Attempts to convert asynchronous prefetch into synchronous upload operations.
"""
# record previous information
orig_prefetch = host_region.bwd_prefetch_region
orig_is_syn = offload_region.is_syn
assert orig_prefetch is not None and not orig_is_syn
host_region.bwd_prefetch_region = None
offload_region.is_syn = True
ts, profit = self._eval_one_choice()
# restore previous information
host_region.bwd_prefetch_region = orig_prefetch
offload_region.is_syn = orig_is_syn
return ts, profit
def _repair_strategy(self):
"""
Repair offload strategy.
It attempts to convert asynchronous prefetch into synchronous upload operations and selects the best one.
The repair process does not end until peak memory is reduced or there is no asynchronous prefetch operation.
"""
print("repair strategy ......")
peak_mem_saving = 0
while len(self.region_to_region_map) and peak_mem_saving <= 0:
max_profit = (0,)
best_ts = None
undo_host_region = None
undo_offload_region = None
for offload_region_id, host_region in self.region_to_region_map.items():
offload_region = self.region_list[offload_region_id]
assert host_region.bwd_prefetch_region == offload_region
assert offload_region.need_offload
assert not offload_region.is_syn
ts, profit = self._try_convert_to_syn_upload(host_region,
offload_region)
if self._compare_profit(profit, max_profit):
undo_host_region = host_region
undo_offload_region = offload_region
max_profit = profit
best_ts = ts
if best_ts is None:
raise NotImplementedError('repair error!')
assert not undo_offload_region.is_syn
undo_offload_region.is_syn = True
undo_host_region.bwd_prefetch_region = None
peak_mem_saving = self.best_ts.peak_mem - best_ts.peak_mem
self._update_state(best_ts)
self.region_to_region_map.pop(undo_offload_region.r_id)
return best_ts
def _eval_one_choice(self):
"""
Evaluate the profit of a strategy choice.
Returns:
AsynTrainingSimulator: the training simulator corresponding to the current strategy.
tuple: contains memory saving and cost information of the current strategy.
"""
ts = AsynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
ts.execute()
extra_comm_cost = max(ts.iter_end_time - self.best_ts.iter_end_time, 0)
profit = self._compute_offload_profit(
ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
return ts, profit
class SolverFactory:
solvers: Dict[str, Type[Solver]] = {
'syn': SynGreedySolver,
'asyn': AsynGreedySolver
}
@staticmethod
def create(solver_name: str) -> Type[Solver]:
if solver_name not in SolverFactory.solvers:
raise TypeError(f"Unknown parameter offload policy {solver_name}")
return SolverFactory.solvers[solver_name]
@staticmethod
def get_solver_names():
return tuple(SolverFactory.solvers.keys())

View File

@ -0,0 +1,458 @@
import bisect
from typing import List, Dict
from collections import OrderedDict
from abc import ABC, abstractmethod
from torch.fx.node import Node
from .region import Region
from .util import *
@dataclass
class ExecutionPeriod:
start_time: float = 0
end_time: float = 0
class TrainingSimulator(ABC):
"""
The Training Simulator is used to simulate the training process.
It records computation, communication, and runtime memory during forward and backward passes.
Args:
region_list (List[Region]): represents the linearized DNN computing graph.
comp_power (float): the NVIDIA GPU FP16 compuing power.
link_to_bw (Dict[str, Dict[float, float]]): communication links and the corresponding bandwidth.
"""
def __init__(self,
region_list: List[Region],
comp_power: float,
link_to_bw: Dict[str, Dict[float, float]]) -> None:
self.region_list = region_list
self.region_num = len(region_list)
self.runtime_mem: int = 0
self.peak_mem: int = 0
self.total_mem_saving: int = 0
self.fwd_node_mem: Dict[Node, float] = {}
self.bwd_node_mem: Dict[Node, float] = {}
# Node dependencies in backward pass
self.bwd_node_deps: Dict[Node, int] = {}
self.comp_power: float = comp_power
self.link_to_bandwidth: Dict[str, Dict[float, float]] = link_to_bw
@abstractmethod
def execute(self):
raise NotImplementedError
@abstractmethod
def _eval_fwd_mem_per_region(self, region: Region):
raise NotImplementedError
@abstractmethod
def _eval_bwd_mem_per_region(self, region: Region):
raise NotImplementedError
def _get_bandwidth(self, link: str, comm_volumn: float) -> float:
"""
Get the data transfer bandwidth.
Args:
link (str): the data transfer link.
comm_volumn (float): the amount of data transferred.
Returns:
float: the data transfer bandwidth.
"""
assert len(self.link_to_bandwidth)
if link not in self.link_to_bandwidth:
raise TypeError(f"Unknown data transfer link {link}")
# size_list = sorted(list(map(float, self.link_to_bandwidth[link].keys())))
size_list = sorted(self.link_to_bandwidth[link].keys())
d_idx = bisect.bisect_left(size_list, comm_volumn)
return self.link_to_bandwidth[link][size_list[d_idx]]
def _get_communication_overhead(self, link: str, comm_volumn: float) -> float:
return comm_volumn / self._get_bandwidth(link, comm_volumn)
def _get_computing_overhead(self, flop: float) -> float:
return flop / self.comp_power
class SynTrainingSimulator(TrainingSimulator):
def __init__(self,
region_list: List[Region],
comp_power: float,
link_to_bw: Dict[str, Dict[float, float]]) -> None:
super().__init__(region_list, comp_power, link_to_bw)
def execute(self):
"""
Simulate synchronous training process.
"""
for reg in self.region_list:
self._eval_fwd_mem_per_region(reg)
for reg in self.region_list.__reversed__():
self._eval_bwd_mem_per_region(reg)
def _eval_fwd_mem_per_region(self, region: Region):
"""
Evaluate the runtime and peak memory when the forward execution reaches the current region.
"""
# upload parameters of the current region
if requires_upload_p_in_fwd(self.region_list[region.shared_rid]):
self.runtime_mem += region.param_size
for node in region.nodes:
self.runtime_mem += calculate_fwd_tmp(node) + \
calculate_fwd_out(node)
self.fwd_node_mem[node] = self.runtime_mem
self.peak_mem = max(self.runtime_mem, self.peak_mem)
self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem
if region.need_offload:
self.runtime_mem -= region.param_size
def _eval_bwd_mem_per_region(self, region: Region):
"""
Evaluate the runtime and peak memory when the backward execution reaches the current region.
"""
# upload parameters of the current region
if region.need_offload:
self.runtime_mem += region.param_size
# add the gradient of the parameter
if region.r_id < region.shared_rid:
# gradient accumulation is required for shared parameters
self.runtime_mem += 2.0 * region.param_size
else:
self.runtime_mem += region.param_size
for node in region.nodes.__reversed__():
self.runtime_mem -= calculate_fwd_out(node)
self.runtime_mem += node.meta['bwd_mem_tmp'] + \
node.meta['bwd_mem_out']
self.peak_mem = max(self.runtime_mem, self.peak_mem)
# The memory savings of a node may be negative due to parameter prefetch.
self.total_mem_saving += node.node_info.runtime_bwd_mem - self.runtime_mem
self.bwd_node_mem[node] = self.runtime_mem
self.runtime_mem -= (node.meta['bwd_mem_tmp'] +
calculate_fwd_tmp(node))
# free bwd_mem_out
self.bwd_node_deps[node] = len(node.all_input_nodes)
for user_node in node.users:
if user_node in self.bwd_node_deps:
self.bwd_node_deps[user_node] -= 1
if self.bwd_node_deps[user_node] <= 0:
self.runtime_mem -= user_node.meta['bwd_mem_out']
if self.runtime_mem < 0:
raise ValueError(f"region id: {region.r_id}, node name: {node.name}, "
f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---"
f"runtime memory computed less than 0, which is miscalculated!")
# release parameter and offload gradient in region
if region.r_id == region.shared_rid:
self.runtime_mem -= 2.0 * region.param_size
elif region.r_id < region.shared_rid:
self.runtime_mem -= 3.0 * region.param_size
elif self.region_list[region.shared_rid].need_offload:
self.runtime_mem -= region.param_size
class AsynTrainingSimulator(TrainingSimulator):
def __init__(self,
region_list: List[Region],
comp_power: float,
link_to_bw: Dict[str, Dict[float, float]]) -> None:
super().__init__(region_list, comp_power, link_to_bw)
self.iter_end_time: int = 0
# the last computation execution period
self.last_comp: ExecutionPeriod = ExecutionPeriod(
start_time=0, end_time=0)
# the last parameter prefetch execution period
self.last_h2d: ExecutionPeriod = ExecutionPeriod(
start_time=0, end_time=0)
# the last gradient offload execution period
self.last_d2h: ExecutionPeriod = ExecutionPeriod(
start_time=0, end_time=0)
# the forward computation execution period of the region
self.fwd_reg_to_comp: OrderedDict[int, ExecutionPeriod] = OrderedDict()
# the forward parameter prefetch execution period of the region
self.fwd_reg_to_pref: OrderedDict[int, ExecutionPeriod] = OrderedDict()
# the backward computation execution period of the region
self.bwd_reg_to_comp: OrderedDict[int, ExecutionPeriod] = OrderedDict()
# the backward parameter prefetch execution period of the region
self.bwd_reg_to_pref: OrderedDict[int, ExecutionPeriod] = OrderedDict()
# the gradient offload execution period of the region
# which is divided into those that are waiting and those that have been released
self.bwd_reg_to_offl_waiting: OrderedDict[int,
ExecutionPeriod] = OrderedDict()
self.bwd_reg_to_offl_freed: OrderedDict[int,
ExecutionPeriod] = OrderedDict()
# the region buffer, which records regions that are offloaded but not released
self.reg_buffer_to_free: List[int] = []
# node dependencies in backward pass
self.bwd_node_deps: Dict[Node, int] = {}
# the region execution flow,
# where fwd_reg_flow[i,j] denotes whether the parameters of j-th region are in the GPU
# when the execution reaches the i-th region.
self.fwd_reg_flow = torch.zeros(
(self.region_num, self.region_num)).bool()
self.bwd_reg_flow = torch.zeros(
(self.region_num, self.region_num)).bool()
def execute(self):
"""
Simulate asynchronous training process.
In forward pass, parameter prefetching is advanced by one region.
In backward pass, parameter prefetching is executed at the specified location,
and gradient offloading is urgent.
"""
for reg in self.region_list:
if reg.param_size and reg.r_id < self.region_num - 1:
for nr in self.region_list[reg.r_id + 1:]:
if nr.param_size and requires_upload_p_in_fwd(self.region_list[nr.shared_rid]):
reg.fwd_prefetch_region = nr
break
self._eval_fwd_cost_per_region(reg)
self._eval_fwd_mem_per_region(reg)
for reg in self.region_list.__reversed__():
self._eval_bwd_cost_per_region(reg)
self._eval_bwd_mem_per_region(reg)
# release remaining grads
for reg_id, offl_exec in self.bwd_reg_to_offl_waiting.items():
self.bwd_reg_to_offl_freed[reg_id] = offl_exec
self.runtime_mem -= self.region_list[reg_id].param_size
self.bwd_reg_to_offl_waiting.clear()
self.iter_end_time = max(
self.last_comp.end_time, self.last_d2h.end_time)
def _insert_h2d_exec(self, region: Region, is_fwd: bool = True):
"""
Insert parameter prefetch execution period of the current region to the end of the h2d stream
"""
pref_start_time = max(self.last_h2d.end_time, self.last_comp.end_time)
pref_end_time = pref_start_time + \
2.0 * self._get_communication_overhead('h2d', region.param_size)
pref_ep = ExecutionPeriod(
start_time=pref_start_time, end_time=pref_end_time)
if is_fwd:
self.fwd_reg_to_pref[region.r_id] = pref_ep
else:
self.bwd_reg_to_pref[region.r_id] = pref_ep
self.last_h2d = pref_ep
def _insert_comp_exec(self, region: Region, is_fwd: bool = True):
"""
Insert computation execution period of the current region to the end of the computing stream
"""
if is_fwd:
reg_to_comp = self.fwd_reg_to_comp
reg_to_pref = self.fwd_reg_to_pref
flop_key = 'fwd_flop'
else:
reg_to_comp = self.bwd_reg_to_comp
reg_to_pref = self.bwd_reg_to_pref
flop_key = 'bwd_flop'
comp_start_time = max(self.last_comp.end_time, reg_to_pref.get(
region.r_id, ExecutionPeriod(0, 0)).end_time)
comp_end_time = comp_start_time + \
sum([self._get_computing_overhead(node.meta.get(flop_key, 0))
for node in region.nodes])
comp_ep = ExecutionPeriod(
start_time=comp_start_time, end_time=comp_end_time)
reg_to_comp[region.r_id] = comp_ep
self.last_comp = comp_ep
def _insert_d2h_exec(self, region: Region):
"""
Insert gradient offload execution period of the current region to the end of the d2h stream
"""
offl_start_time = max(self.last_d2h.end_time, self.last_comp.end_time)
offl_end_time = offl_start_time + \
self._get_communication_overhead('d2h', region.param_size)
offl_ep = ExecutionPeriod(
start_time=offl_start_time, end_time=offl_end_time)
self.bwd_reg_to_offl_waiting[region.r_id] = offl_ep
self.last_d2h = offl_ep
def _eval_fwd_cost_per_region(self, region: Region):
"""
Evaluate computation and communication execution period of the region in forward pass.
"""
# upload parameters of the first region
if region.r_id == 0:
self._insert_h2d_exec(region)
# prefetch parameters of the next region
fwd_prefetch_region = region.fwd_prefetch_region
if fwd_prefetch_region and requires_upload_p_in_fwd(self.region_list[fwd_prefetch_region.shared_rid]):
self._insert_h2d_exec(fwd_prefetch_region)
# execute computation
self._insert_comp_exec(region)
def _eval_fwd_mem_per_region(self, region: Region):
"""
Evaluate the runtime and peak memory when the forward execution reaches the current region.
"""
# upload parameters of the current region
if region.r_id <= 0:
self.runtime_mem += region.param_size
self.fwd_reg_flow[region.r_id, region.r_id] = True
else:
self.fwd_reg_flow[region.r_id] = self.fwd_reg_flow[region.r_id - 1]
self.fwd_reg_flow[region.r_id,
self.reg_buffer_to_free] = False
self.reg_buffer_to_free.clear()
# prefetch parameters of the next region
fwd_prefetch_region = region.fwd_prefetch_region
if fwd_prefetch_region and requires_upload_p_in_fwd(self.region_list[fwd_prefetch_region.shared_rid]):
self.runtime_mem += fwd_prefetch_region.param_size
self.fwd_reg_flow[region.r_id,
fwd_prefetch_region.r_id] = True
for node in region.nodes:
self.runtime_mem += calculate_fwd_tmp(node) + \
calculate_fwd_out(node)
self.peak_mem = max(self.runtime_mem, self.peak_mem)
self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem
self.fwd_node_mem[node] = self.runtime_mem
if region.need_offload:
self.runtime_mem -= region.param_size
assert len(
self.reg_buffer_to_free) <= 1, f'{len(self.reg_buffer_to_free)}'
self.reg_buffer_to_free.append(region.r_id)
def _eval_bwd_cost_per_region(self, region: Region):
"""
Evaluate computation and communication execution period of the region in backward pass.
"""
# upload parameters of the current region
if region.is_syn:
assert region.need_offload
self._insert_h2d_exec(region, is_fwd=False)
# prefetch parameters of the region choiced, which is parallel to computation
if region.bwd_prefetch_region is not None:
self._insert_h2d_exec(region.bwd_prefetch_region, is_fwd=False)
# execute computation
self._insert_comp_exec(region, is_fwd=False)
# offload gradient
if requires_offload_g_in_bwd(region):
self._insert_d2h_exec(region)
assert len(self.reg_buffer_to_free) == 0
for reg_id, offl_exec in self.bwd_reg_to_offl_waiting.items():
if offl_exec.end_time >= self.last_comp.start_time:
break
self.reg_buffer_to_free.append(reg_id)
self.bwd_reg_to_offl_freed[reg_id] = offl_exec
for reg_id in self.reg_buffer_to_free:
self.bwd_reg_to_offl_waiting.pop(reg_id)
def _eval_bwd_mem_per_region(self, region: Region):
"""
Evaluate the runtime and peak memory when the backward execution reaches the current region.
"""
if region.r_id + 1 < self.region_num:
self.bwd_reg_flow[region.r_id] = self.bwd_reg_flow[region.r_id + 1]
else:
self.bwd_reg_flow[region.r_id] = self.fwd_reg_flow[-1]
self.bwd_reg_flow[region.r_id,
self.reg_buffer_to_free] = False
# free gradients in the buffer
while len(self.reg_buffer_to_free):
reg_id = self.reg_buffer_to_free.pop(0)
self.runtime_mem -= self.region_list[reg_id].param_size
# upload parameters of the current region
if region.is_syn:
self.runtime_mem += region.param_size
self.bwd_reg_flow[region.r_id, region.r_id] = True
# prefetch parameters of the region choiced
bwd_prefetch_region = region.bwd_prefetch_region
if bwd_prefetch_region:
self.runtime_mem += bwd_prefetch_region.param_size
self.bwd_reg_flow[region.r_id,
bwd_prefetch_region.r_id] = True
# add the gradient of the parameter
if region.r_id < region.shared_rid:
# gradient accumulation is required for shared parameters
self.runtime_mem += 2.0 * region.param_size
else:
self.runtime_mem += region.param_size
for node in region.nodes.__reversed__():
self.runtime_mem -= calculate_fwd_out(node)
self.runtime_mem += node.meta['bwd_mem_tmp'] + \
node.meta['bwd_mem_out']
self.peak_mem = max(self.runtime_mem, self.peak_mem)
# The memory savings of a node may be negative due to parameter prefetch.
self.total_mem_saving += node.node_info.runtime_bwd_mem - self.runtime_mem
self.bwd_node_mem[node] = self.runtime_mem
self.runtime_mem -= (node.meta['bwd_mem_tmp'] +
calculate_fwd_tmp(node))
# free bwd_mem_out
self.bwd_node_deps[node] = len(node.all_input_nodes)
for user_node in node.users:
if user_node in self.bwd_node_deps:
self.bwd_node_deps[user_node] -= 1
if self.bwd_node_deps[user_node] <= 0:
self.runtime_mem -= user_node.meta['bwd_mem_out']
if self.runtime_mem < 0:
raise ValueError(f"region id: {region.r_id}, node name: {node.name}, "
f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---"
f"runtime memory computed less than 0, which is miscalculated!")
# release parameters of the region
if requires_release_p_in_bwd(self.region_list[region.shared_rid]):
self.runtime_mem -= region.param_size

View File

@ -0,0 +1,90 @@
from dataclasses import dataclass
from typing import List
import torch
from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp
from .region import Region
@dataclass
class NodeInfo:
node_id: int = 0
runtime_fwd_mem: float = 0
runtime_bwd_mem: float = 0
class NvDevicePower:
"""
NVIDIA GPU computing performance (TFLOPs).
"""
RTX3080_FP16 = 70
RTX3080_FP32 = 34.1
RTX3090_FP16 = 71
RTX3090_FP32 = 35.7
V100_FP16 = 31.4
V100_FP32 = 15.7
A100_FP16 = 78
A100_FP32 = 19.5
class GlobalRuntimeInfo:
h2d_stream = torch.cuda.Stream()
d2h_stream = torch.cuda.Stream()
fwd_prefetch_event_map = {}
bwd_prefetch_event_map = {}
region_list = []
def compute_act_peak_mem(region_list: List[Region]) -> float:
act_peak_mem = 0
runtime_mem = 0
# forward
for region in region_list:
for node in region.nodes:
runtime_mem = runtime_mem + \
calculate_fwd_tmp(node) + calculate_fwd_out(node)
act_peak_mem = max(runtime_mem, act_peak_mem)
# backward
bwd_deps = {}
for region in region_list.__reversed__():
for node in region.nodes.__reversed__():
runtime_mem -= calculate_fwd_out(node)
runtime_mem = runtime_mem + \
node.meta['bwd_mem_tmp'] + node.meta['bwd_mem_out']
act_peak_mem = max(runtime_mem, act_peak_mem)
runtime_mem = runtime_mem - \
node.meta['bwd_mem_tmp'] - calculate_fwd_tmp(node)
# free bwd_mem_out
bwd_deps[node] = len(node.all_input_nodes)
for user_node in node.users:
if user_node in bwd_deps:
bwd_deps[user_node] -= 1
if bwd_deps[user_node] <= 0:
runtime_mem -= user_node.meta['bwd_mem_out']
return act_peak_mem
def compute_max_param_mem(region_list: List[Region]) -> float:
return max(region.param_size for region in region_list)
def compute_total_param_mem(region_list: List[Region]) -> float:
return sum(region.param_size for region in region_list if region.r_id <= region.shared_rid)
def requires_upload_p_in_fwd(shared_reg: Region):
return (shared_reg.r_id >= shared_reg.shared_rid) or (
shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload)
def requires_release_p_in_bwd(shared_reg: Region):
return (shared_reg.r_id >= shared_reg.shared_rid) or (
shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload)
def requires_offload_g_in_bwd(region: Region):
return region.param_size and (region.r_id <= region.shared_rid)

View File

@ -0,0 +1,37 @@
# Auto-Offload Demo with GPT2
## Requirements
Before you can launch training, you need to install the following requirements.
### Install PyTorch
```bash
#conda
conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch
#pip
pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113
```
### Install [Colossal-AI v0.2.0](https://colossalai.org/download/) From Official Website
```bash
pip install colossalai==0.2.0+torch1.12cu11.3 -f https://release.colossalai.org
```
### Install transformers
```bash
pip install transformers
```
## Dataset
For simplicity, the input data is randonly generated here.
## Training
```bash
#Run the auto offload on GPT with default setting and a dummy dataset.
bash run.sh
```

View File

@ -0,0 +1,65 @@
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel
class GPTLMModel(nn.Module):
def __init__(self,
hidden_size=768,
num_layers=12,
num_attention_heads=12,
max_seq_len=1024,
vocab_size=50257):
super().__init__()
self.model = GPT2LMHeadModel(
GPT2Config(n_embd=hidden_size,
n_layer=num_layers,
n_head=num_attention_heads,
n_positions=max_seq_len,
n_ctx=max_seq_len,
vocab_size=vocab_size))
def forward(self, input_ids, attention_mask):
# Only return lm_logits
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0]
class GPTLMLoss(nn.Module):
def __init__(self):
super().__init__()
self.loss_fn = nn.CrossEntropyLoss()
def forward(self, logits, labels):
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
def get_gpt2_components(model_type: str, batch_size: int):
vocab_size = 1024
seq_len = 8
def gpt2_model_builder():
if model_type == "gpt2_medium":
return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16)
elif model_type == "gpt2_xl":
return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32)
elif model_type == "gpt2_10b":
return GPTLMModel(hidden_size=4096, num_layers=50, num_attention_heads=16)
elif model_type == "gpt2_14b":
return GPTLMModel(hidden_size=4096, num_layers=70, num_attention_heads=16)
elif model_type == "gpt2_20b":
return GPTLMModel(hidden_size=8192, num_layers=25, num_attention_heads=16)
elif model_type == "gpt2_24b":
return GPTLMModel(hidden_size=8192, num_layers=30, num_attention_heads=16)
else:
raise TypeError(f"model_builder {model_type}")
def gpt2_data_gen(device="cuda"):
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
attention_mask = torch.ones_like(input_ids, device=device)
kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
return kwargs
return gpt2_model_builder, gpt2_data_gen

View File

@ -0,0 +1,2 @@
colossalai >= 0.1.12
torch >= 1.8.1

View File

@ -0,0 +1,8 @@
export BATCH_SIZE=${BATCH_SIZE:-64}
export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"}
export MEMORY_BUDGET=${MEMORY_BUDGET:-16}
export SOLVER_TYPE=${SOLVER_TYPE:-"asyn"}
mkdir -p offload_logs
python train_gpt_offload.py --model_type=${MODEL_TYPE} --memory_budget=${MEMORY_BUDGET} --solver_type=${SOLVER_TYPE} --batch_size=${BATCH_SIZE} 2>&1 | tee ./offload_logs/${MODEL_TYPE}_bs_${BATCH_SIZE}_st_${SOLVER_TYPE}.log

View File

@ -0,0 +1,94 @@
import time
import pytest
import argparse
from functools import partial
import torch
from torch.utils._pytree import tree_map
import torch.multiprocessing as mp
import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.fx.profiler import parameter_size
from colossalai.utils import free_port, get_current_device
from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer
from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
from colossalai.auto_parallel.offload.solver import NOT_NVML
from model_zoo import get_gpt2_components, GPTLMLoss
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--model_type', type=str, default="gpt2_medium")
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--solver_type', type=str, default='asyn')
parser.add_argument('--memory_budget', type=float, default=16)
return parser.parse_args()
@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
def train_gpt(args):
memory_budget = args.memory_budget * 1024 * 1024 * 1024
solver_type = args.solver_type
model_type = args.model_type
batch_size = args.batch_size
# build model
model_builder, data_gen = get_gpt2_components(model_type=model_type, batch_size=batch_size)
label = torch.randint(low=0, high=128, size=(64, 8,), device=get_current_device())
criterion = GPTLMLoss()
start_time = time.time()
model = model_builder()
model.train()
param_size = parameter_size(model) / 1024 ** 2 / 2
init_time = time.time() - start_time
print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s")
data_args = data_gen(device="cpu")
wrap_fn = lambda x: x.to(dtype=torch.half) if isinstance(x, torch.Tensor) and torch.is_floating_point(x) else x
data_args = tree_map(wrap_fn, data_args)
start_time = time.time()
model = memory_optimize(model, data_args, memory_budget, solver_type)
solver_time = time.time() - start_time
print(f"solver_time={solver_time:.3f} s")
hybrid_optimizer = HybridAdam(model.model.parameters(), lr=1e-3)
optim = AMPOptimizer(hybrid_optimizer, model)
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
time_list = []
data_args = data_gen(device="cuda")
data_args = tree_map(wrap_fn, data_args)
for step in range(10):
optim.zero_grad()
torch.cuda.synchronize()
start_time = time.time()
loss = criterion(model(**data_args), label)
optim.backward(loss)
torch.cuda.synchronize()
time_list.append(time.time() - start_time)
optim.step()
torch.cuda.synchronize()
exec_time = sum(sorted(time_list)[:5]) / 5
runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2
runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2
print(f'solver_type: {solver_type} | model_type: {model_type}')
print(
f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|'
)
print(time_list)
def run(rank, world_size, port, args):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
train_gpt(args)
if __name__ == '__main__':
args = parse_args()
run_func = partial(run, world_size=1, port=free_port(), args=args)
mp.spawn(run_func, nprocs=1)

View File

@ -0,0 +1,86 @@
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel
from transformers import BertConfig, BertLMHeadModel
from tests.components_to_test.registry import non_distributed_component_funcs
class GPTLMModel(nn.Module):
def __init__(self,
hidden_size=768,
num_layers=12,
num_attention_heads=12,
max_seq_len=1024,
vocab_size=50257):
super().__init__()
self.model = GPT2LMHeadModel(
GPT2Config(n_embd=hidden_size,
n_layer=num_layers,
n_head=num_attention_heads,
n_positions=max_seq_len,
n_ctx=max_seq_len,
vocab_size=vocab_size))
def forward(self, input_ids, attention_mask):
# Only return lm_logits
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0]
class LMLoss(nn.Module):
def __init__(self):
super().__init__()
self.loss_fn = nn.CrossEntropyLoss()
def forward(self, logits, labels):
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
class BertLMModel(nn.Module):
def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=32, vocab_size=30522):
super().__init__()
self.model = BertLMHeadModel(BertConfig(n_embd=hidden_size, num_hidden_layers=num_layers, hidden_size=hidden_size,
num_attention_heads=num_attention_heads, max_position_embeddings=hidden_size,
vocab_size=vocab_size))
def forward(self, input_ids, attention_mask):
# Only return lm_logits
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0]
@non_distributed_component_funcs.register(name='bert_')
def get_bert_components():
vocab_size = 1024
seq_len = 64
batchSize = 64
def bert_model_builder():
model = BertLMModel(hidden_size=8192, num_layers=4, num_attention_heads=32, vocab_size=vocab_size)
return model
def bert_data_gen(device="meta"):
input_ids = torch.randint(0, vocab_size, (batchSize, seq_len), device=device)
attention_mask = torch.ones_like(input_ids, device=device)
kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
return kwargs
return bert_model_builder, bert_data_gen
@non_distributed_component_funcs.register(name='gpt2_')
def get_gpt2_components():
vocab_size = 1024
seq_len = 8
batchSize = 64
def gpt2_model_builder():
model = GPTLMModel(hidden_size=8192, num_layers=2, num_attention_heads=32, vocab_size=vocab_size)
return model
def gpt2_data_gen(device="meta"):
input_ids = torch.randint(0, vocab_size, (batchSize, seq_len), device=device)
attention_mask = torch.ones_like(input_ids, device=device)
kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
return kwargs
return gpt2_model_builder, gpt2_data_gen

View File

@ -0,0 +1,150 @@
import time
import pytest
from functools import partial
import torch
from torch.utils._pytree import tree_map
import torch.multiprocessing as mp
import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.fx.profiler import parameter_size
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.utils import free_port, get_current_device
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer
from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
from colossalai.auto_parallel.offload.solver import NOT_NVML
from colossalai.testing import parameterize
from tests.test_tensor.common_utils import set_seed
from tests.test_auto_parallel.test_offload.model_utils import *
@parameterize('model_name', ['gpt2_'])
@parameterize('memory_budget', [5000])
@parameterize('solver_name', ['asyn'])
def exam_fwd_bwd(
model_name: str,
memory_budget: float,
solver_name: str
):
# build model
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, data_gen = get_components_func()
label = torch.randint(low=0, high=128, size=(64, 8,), device=get_current_device())
criterion = LMLoss()
set_seed(42)
start_time = time.time()
model = model_builder()
model.train()
param_size = parameter_size(model) / 1024 ** 2 / 2
init_time = time.time() - start_time
print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s")
data_args = data_gen(device="cpu")
wrap_fn = lambda x: x.to(dtype=torch.half) if isinstance(x, torch.Tensor) and torch.is_floating_point(x) else x
data_args = tree_map(wrap_fn, data_args)
start_time = time.time()
model = memory_optimize(model, data_args, memory_budget * 1024 * 1024, solver_name)
solver_time = time.time() - start_time
print(f"solver_time={solver_time:.3f} s")
hybrid_optimizer = HybridAdam(model.model.parameters(), lr=1e-3)
optim = AMPOptimizer(hybrid_optimizer, model)
with ColoInitContext(device=torch.device('cpu')):
gemini_model = model_builder()
gemini_model.train()
hybrid_optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3)
gemini_config = dict(strict_ddp_mode=False,
device=torch.device('cpu'),
placement_policy='cpu',
pin_memory=True,
hidden_dim=8192,
search_range_mb=128)
gemini_model = zero_model_wrapper(gemini_model, 3, gemini_config)
optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
gemini_optim = zero_optim_wrapper(gemini_model, hybrid_optimizer, optim_config=optim_config)
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
# test gemini
time_list = []
set_seed(42)
data_args = data_gen(device="cuda")
for step in range(10):
gemini_optim.zero_grad()
torch.cuda.synchronize()
start_time = time.time()
gemini_out = gemini_model(**data_args)
gemini_loss = criterion(gemini_out, label)
gemini_optim.backward(gemini_loss)
torch.cuda.synchronize()
time_list.append(time.time() - start_time)
gemini_optim.step()
torch.cuda.synchronize()
exec_time = sum(sorted(time_list)[:5]) / 5
runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2
runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2
print(f'gemini | model_name: {model_name}')
print(
f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|'
)
print(time_list)
del data_args
del gemini_model
del gemini_optim
del gemini_out
del gemini_loss
# test asyn offload
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
time_list = []
set_seed(42)
data_args = data_gen(device="cuda")
data_args = tree_map(wrap_fn, data_args)
for step in range(10):
optim.zero_grad()
torch.cuda.synchronize()
start_time = time.time()
loss = criterion(model(**data_args), label)
optim.backward(loss)
torch.cuda.synchronize()
time_list.append(time.time() - start_time)
optim.step()
torch.cuda.synchronize()
exec_time = sum(sorted(time_list)[:5]) / 5
runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2
runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2
print(f'solver_name: {solver_name} | model_name: {model_name}')
print(
f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|'
)
print(time_list)
@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
def test_perf(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_fwd_bwd()
if __name__ == '__main__':
run_func = partial(test_perf, world_size=1, port=free_port())
mp.spawn(run_func, nprocs=1)

View File

@ -0,0 +1,62 @@
import pytest
import torch.fx
from torch.fx import GraphModule
from torch.utils._pytree import tree_map
from colossalai.fx import ColoTracer, is_compatible_with_meta
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.auto_parallel.offload.region_manager import RegionManager
from colossalai.auto_parallel.offload.solver import SolverFactory, NOT_NVML
from colossalai.testing import parameterize
from tests.test_auto_parallel.test_offload.model_utils import *
@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
@parameterize('model_name', ['gpt2_', 'bert_'])
@parameterize('memory_budget', [4000])
@parameterize('solver_name', ['syn', 'asyn'])
def solver_test(model_name: str,
memory_budget: float,
solver_name: str):
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, data_gen = get_components_func()
data_args = data_gen(device="cpu")
wrap_fn = lambda x: x.to(dtype=torch.half) if isinstance(x, torch.Tensor) and torch.is_floating_point(x) else x
data_args = tree_map(wrap_fn, data_args)
model = model_builder()
model.train()
model = model.cpu().half()
tracer = ColoTracer()
assert is_compatible_with_meta()
wrap_fn = lambda x: x.to("meta") if isinstance(x, torch.Tensor) else x
meta_args = tree_map(wrap_fn, data_args)
graph = tracer.trace(model, meta_args=meta_args)
gm = GraphModule(model, graph, model.__class__.__name__)
interp = MetaInfoProp(gm)
interp.propagate(*meta_args.values())
region_manager = RegionManager(graph, solver_name=solver_name)
region_manager._pre_process()
region_list = region_manager.region_list
solver_cls = SolverFactory.create(solver_name)
memory_budget = memory_budget * 1024 * 1024
solver = solver_cls(region_list, memory_budget)
solver._call_solver()
assert solver.best_ts.peak_mem < memory_budget
print("****************** execution plan *******************")
for region in region_list:
need_offload = region.need_offload
to_prefetch = region.fwd_prefetch_region.r_id if region.fwd_prefetch_region is not None else None
print(f'| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}')
for region in region_list.__reversed__():
need_offload = region.need_offload
to_prefetch = region.bwd_prefetch_region.r_id if region.bwd_prefetch_region is not None else None
print(f'| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}')
if __name__ == '__main__':
solver_test()