mirror of https://github.com/hpcaitech/ColossalAI
499 lines
18 KiB
Python
499 lines
18 KiB
Python
import time
|
|
from abc import ABC, abstractmethod
|
|
from typing import Dict, List, Type
|
|
|
|
# pynvml is optional at import time: NOT_NVML records whether the NVML bindings
# are unavailable so that callers can check before using any nvml* function.
NOT_NVML = False
try:
    from pynvml import *
except ImportError:
    # Only a missing/broken pynvml installation is tolerated here; other
    # exceptions (e.g. KeyboardInterrupt) are no longer swallowed.
    NOT_NVML = True
|
|
|
|
import torch
|
|
from torch.fx.node import Node
|
|
|
|
from colossalai.utils.cuda import get_current_device
|
|
|
|
from .region import Region
|
|
from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator
|
|
from .util import NodeInfo, NvDevicePower
|
|
|
|
|
|
def benchmark_func(func, number=1, repeat=1, warmup=3):
    """
    Benchmark the average wall-clock cost of a single call to ``func``
    (used in this file to measure data transfer cost).

    ``func`` is first called ``warmup`` times without being timed. Then ``repeat``
    rounds are measured; each round calls ``func`` ``number`` times between two
    ``torch.cuda.synchronize()`` calls.

    Args:
        func (Callable): a zero-argument callable to measure.
        number (int): number of calls per timed round.
        repeat (int): number of timed rounds.
        warmup (int): number of untimed warm-up calls.

    Returns:
        float: the per-call time in seconds, averaged over all rounds.

    Raises:
        ValueError: if ``number`` or ``repeat`` is not positive.
    """

    # Fail early with a clear message instead of a ZeroDivisionError after the work ran.
    if number <= 0 or repeat <= 0:
        raise ValueError(f"number and repeat must be positive, got number={number}, repeat={repeat}")

    for _ in range(warmup):
        func()

    costs = []

    for _ in range(repeat):
        torch.cuda.synchronize()
        # perf_counter is monotonic and high resolution, unlike time.time().
        begin = time.perf_counter()
        for _ in range(number):
            func()
        torch.cuda.synchronize()
        costs.append((time.perf_counter() - begin) / number)

    return sum(costs) / len(costs)
|
|
|
|
|
|
class Solver(ABC):
    """
    The parameter offload solver.

    Args:
        region_list (List[Region]): represents the linearized DNN computing graph.
        memory_budget (float): the given memory budget. A non-positive value means
            "use the total memory of the current CUDA device".
        error_factor (float): the error factor.
            The memory budget is multiplied by it, to leave headroom for errors
            in the estimation of peak memory and execution time.
    """

    def __init__(self, region_list: List[Region], memory_budget: float = -1.0, error_factor: float = 0.95) -> None:
        self.region_list = region_list

        self.error_factor: float = error_factor
        if memory_budget > 0:
            self.memory_budget = memory_budget * self.error_factor
        else:
            self.memory_budget = torch.cuda.get_device_properties(get_current_device()).total_memory * self.error_factor

        # Both profiling steps touch the GPU and run once per solver instance.
        self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth()
        self.comp_power: float = self._extract_computing_power()

    @abstractmethod
    def _call_solver(self):
        raise NotImplementedError

    @abstractmethod
    def _try_to_offload(self, *args):
        raise NotImplementedError

    @abstractmethod
    def _eval_one_choice(self, *args):
        raise NotImplementedError

    def _compute_offload_profit(self, total_mem_saving: float, peak_mem_saving: float, extra_cost: float):
        """
        Compute the profits of the offload strategies,
        which packages the memory savings information for subsequent comparisons.

        Args:
            total_mem_saving (float): the total memory saving of the offload strategy.
            peak_mem_saving (float): the peak memory saving of the offload strategy.
            extra_cost (float): extra data transfer cost.

        Returns:
            tuple: profit information ``(saving per unit cost, total_mem_saving, peak_mem_saving)``;
                the first term represents memory savings per unit of time.
        """

        if extra_cost == 0:
            # means data transfer overhead can be completely overlapped
            return (float("inf"), total_mem_saving, peak_mem_saving)
        return (total_mem_saving / extra_cost, total_mem_saving, peak_mem_saving)

    def _compare_profit(self, profit_a: tuple, profit_b: tuple) -> bool:
        """
        Compare the profits of the two offload strategies in lexicographic (dictionary) order.

        Only the positions present in both tuples are compared, so a short
        sentinel such as ``(0,)`` compares against the first term only.

        Args:
            profit_a (tuple): the profit of a offload strategy.
            profit_b (tuple): the profit of another offload strategy.

        Returns:
            bool: whether profit_a is greater than profit_b.
        """

        for val1, val2 in zip(profit_a, profit_b):
            if val1 != val2:
                return val1 > val2
        return False

    def _update_state(self, best_ts: TrainingSimulator):
        """
        Update the solver state: remember ``best_ts`` and copy its per-node
        memory figures onto the graph nodes.
        """

        self.best_ts = best_ts
        self._update_node_mem_info(best_ts.fwd_node_mem, best_ts.bwd_node_mem)

    def _update_node_mem_info(self, fwd_mem_info: Dict[Node, float], bwd_mem_info: Dict[Node, float]):
        """
        Update the runtime memory information of the node.

        Args:
            fwd_mem_info (Dict[Node, float]): the runtime memory of each node in forward pass.
            bwd_mem_info (Dict[Node, float]): the runtime memory of each node in backward pass.
        """

        for node, mem in fwd_mem_info.items():
            assert hasattr(node, "node_info") and isinstance(node.node_info, NodeInfo)
            node.node_info.runtime_fwd_mem = mem
        for node, mem in bwd_mem_info.items():
            assert hasattr(node, "node_info") and isinstance(node.node_info, NodeInfo)
            node.node_info.runtime_bwd_mem = mem

    def _extract_computing_power(self):
        """
        return the FP16 computing performance of the current NVIDIA GPU.

        The device model is read through NVML from device index 0 and matched
        against the models listed in ``NvDevicePower``.

        Raises:
            RuntimeError: pynvml could not be imported.
            TypeError: Unknown NVIDIA GPU device.
        """

        if NOT_NVML:
            raise RuntimeError("pynvml is required to query the GPU model, but it could not be imported")

        nvmlInit()
        try:
            handle = nvmlDeviceGetHandleByIndex(0)
            device_name = nvmlDeviceGetName(handle)
        finally:
            # Balance nvmlInit() so the NVML session is not leaked.
            nvmlShutdown()

        # Some pynvml versions return the name as bytes instead of str.
        if isinstance(device_name, bytes):
            device_name = device_name.decode("utf-8", errors="replace")

        # NOTE(review): the NvDevicePower values are presumably TFLOPS, hence the 1e12 scale — confirm.
        units = 1e12

        if "RTX 3080" in device_name:
            return NvDevicePower.RTX3080_FP16 * units
        elif "RTX 3090" in device_name:
            return NvDevicePower.RTX3090_FP16 * units
        elif "V100" in device_name:
            return NvDevicePower.V100_FP16 * units
        elif "A100" in device_name:
            return NvDevicePower.A100_FP16 * units
        else:
            raise TypeError(f"Unknown NVIDIA GPU device name {device_name}")

    def _profile_bandwidth(self):
        """
        Profile the bidirectional communication bandwidth between CPU and GPU
        using data volumes ranging from 1KB to 1GB.

        Returns:
            Dict[str, Dict[int, float]]: for each link (``"h2d"`` and ``"d2h"``), a mapping
                from transfer size in bytes to the measured bandwidth in bytes per second.
        """

        print("profiling bandwidth ......")
        link_to_bandwidth = {}
        links = ["h2d", "d2h"]

        for link in links:
            t_size = 1024
            size_to_bandwidth = {}

            # from 1KB to 1GB: 21 sizes, doubling each step
            for _ in range(21):
                # int8 tensors, so the element count equals the size in bytes;
                # the host-side tensor is always allocated in pinned memory.
                if link == "h2d":
                    src_tensor = torch.ones(int(t_size), dtype=torch.int8, pin_memory=True)
                    dst_tensor = torch.ones((int(t_size)), dtype=torch.int8, device="cuda")
                elif link == "d2h":
                    src_tensor = torch.ones(int(t_size), dtype=torch.int8, device="cuda")
                    dst_tensor = torch.ones((int(t_size)), dtype=torch.int8, pin_memory=True)

                def func():
                    dst_tensor.copy_(src_tensor)

                size_to_bandwidth[t_size] = t_size / benchmark_func(func, number=5, repeat=3)
                print(
                    f"size: {t_size / 1024 ** 2:.3f} MB, "
                    f"{src_tensor.device.type}-to-{dst_tensor.device.type} "
                    f"bandwidth: {size_to_bandwidth[t_size] / 1024 ** 3:.3f} GB/s"
                )

                t_size *= 2

            link_to_bandwidth[link] = size_to_bandwidth
        return link_to_bandwidth
|
|
|
|
|
|
class SynGreedySolver(Solver):
    """
    Greedy solver that marks regions for synchronous offloading
    (``need_offload`` and ``is_syn`` are set on the chosen regions).

    Args:
        region_list (List[Region]): represents the linearized DNN computing graph.
        memory_budget (float): the given memory budget.
    """

    def __init__(self, region_list: List[Region], memory_budget: float = -1.0) -> None:
        super().__init__(region_list, memory_budget)

        self.best_ts: SynTrainingSimulator = None
        self._init_state()

    def _init_state(self):
        """
        Initialize the solver state when without offloading.
        """

        ts = SynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
        ts.execute()
        self._update_state(ts)

    def _call_solver(self):
        """
        Call the solver to search an efficient parameter offloading strategy for the linearized graph.
        The solver adopts greedy algorithm: while the simulated peak memory exceeds the
        budget, the not-yet-offloaded region with the greatest profit is offloaded.

        Raises:
            NotImplementedError: Unable to find a solution for the given memory budget.
        """

        print("search offloading strategy ......")
        while self.best_ts.peak_mem > self.memory_budget:
            offload_region = None
            best_ts = None
            # One-element sentinel: a candidate must have a first profit term above 0.
            max_profit = (0,)

            # search which region should be offloaded,
            # the last region does not need to be offloaded.
            for region in self.region_list[:-1]:
                if region.param_size and not region.need_offload:
                    temp_ts, profit = self._try_to_offload(region)
                    if self._compare_profit(profit, max_profit):
                        offload_region = region
                        max_profit = profit
                        best_ts = temp_ts

            if offload_region is not None and best_ts is not None:
                offload_region.need_offload = True
                offload_region.is_syn = True
                self._update_state(best_ts)
            else:
                raise NotImplementedError(
                    f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, "
                    f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!"
                )

    def _call_solver_l2l(self):
        """
        The layer-wise offload strategy: every region except the last one is
        marked for synchronous offloading.
        """

        for region in self.region_list[:-1]:
            region.need_offload = True
            region.is_syn = True

    def _try_to_offload(self, offload_region: Region):
        """
        Tentatively offload ``offload_region``, evaluate that choice, and restore the region.

        Args:
            offload_region (Region): a region that is not offloaded yet.

        Returns:
            SynTrainingSimulator: the training simulator corresponding to the tentative strategy.
            tuple: the profit of the tentative strategy.
        """

        # record previous information
        orig_need_offload = offload_region.need_offload
        assert not orig_need_offload
        offload_region.need_offload = True

        try:
            ts, profit = self._eval_one_choice(offload_region)
        finally:
            # restore previous information, even if the simulation raised,
            # so a failed trial cannot leave the region marked as offloaded
            offload_region.need_offload = orig_need_offload
        return ts, profit

    def _eval_one_choice(self, offload_region: Region):
        """
        Evaluate the profit of a strategy choice.

        Args:
            offload_region (Region): the offload region of current choice.

        Returns:
            SynTrainingSimulator: the training simulator corresponding to the current strategy.
            tuple: contains memory saving and cost information of the current strategy.
        """

        ts = SynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
        ts.execute()

        # NOTE(review): the factor 2.0 presumably accounts for one transfer in each direction — confirm.
        extra_comm_cost = 2.0 * ts._get_communication_overhead("h2d", offload_region.param_size)
        # the shared region needs to be moved twice
        if offload_region.r_id < offload_region.shared_rid:
            extra_comm_cost *= 2.0
        profit = self._compute_offload_profit(ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)

        return ts, profit
|
|
|
|
|
|
class AsynGreedySolver(Solver):
    """
    Greedy solver that offloads regions and searches, for each of them, a later
    region during which it is prefetched in the backward pass.

    Args:
        region_list (List[Region]): represents the linearized DNN computing graph.
        memory_budget (float): the given memory budget.
        search_window_size (int): how many regions following an offloaded region
            are considered as candidates to host its prefetch.
    """

    def __init__(self, region_list: List[Region], memory_budget: float = -1.0, search_window_size: int = 3):
        super().__init__(region_list, memory_budget)

        self.search_window_size = search_window_size
        # Records the prefetch execution location of the offloaded region
        # (offloaded region id -> host region), for committed asynchronous prefetches only.
        self.region_to_region_map = {}
        self.best_ts: AsynTrainingSimulator = None

        self._init_state()

    def _init_state(self):
        """
        Initialize the solver state when without offloading.
        """

        ts = AsynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
        ts.execute()
        self._update_state(ts)
        print("init peak memory", self.best_ts.peak_mem / 1024**2, "MB")

    def _call_solver(self):
        """
        Call the solver to search an efficient parameter offloading strategy for the linearized graph.
        The solver adopts greedy algorithm.

        When no region can be offloaded profitably but asynchronous prefetches have
        already been committed, the strategy is repaired instead of giving up.

        Raises:
            NotImplementedError: Unable to find a solution for the given memory budget.
        """

        print("search for offloading strategy ......")
        # Per-iteration scratch map: region id -> best host region found in this round.
        # Kept separate from self.region_to_region_map, which holds committed choices.
        candidate_hosts = {}
        while self.best_ts.peak_mem > self.memory_budget:
            region_to_offload = None
            max_offload_profit = (0,)
            best_offl_ts = None

            # search which region should be offloaded,
            # the last region does not need to be offloaded
            for region in self.region_list[:-1]:
                if region.param_size and not region.need_offload:
                    max_prefetch_profit = (0,)
                    best_pref_ts = None

                    # search when to prefetch the region offloaded
                    for host_region in self.region_list[region.r_id + 1 : region.r_id + 1 + self.search_window_size]:
                        # a host region can carry only one prefetch
                        if host_region.bwd_prefetch_region is not None:
                            continue

                        temp_ts, profit = self._try_to_offload(host_region, region)

                        if self._compare_profit(profit, max_prefetch_profit):
                            candidate_hosts[region.r_id] = host_region
                            max_prefetch_profit = profit
                            best_pref_ts = temp_ts
                            # the transfer is fully overlapped; no better host can exist
                            if profit[0] == float("inf"):
                                break

                    if self._compare_profit(max_prefetch_profit, max_offload_profit):
                        region_to_offload = region
                        max_offload_profit = max_prefetch_profit
                        best_offl_ts = best_pref_ts

            if (region_to_offload is not None) and (best_offl_ts is not None):
                region_to_offload.need_offload = True
                chosen_host = candidate_hosts[region_to_offload.r_id]
                if chosen_host == region_to_offload:
                    region_to_offload.is_syn = True
                else:
                    chosen_host.bwd_prefetch_region = region_to_offload
                    self.region_to_region_map[region_to_offload.r_id] = chosen_host

                self._update_state(best_offl_ts)

            elif self.region_to_region_map:
                self._repair_strategy()
            else:
                raise NotImplementedError(
                    f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, "
                    f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!"
                )

            candidate_hosts.clear()

    def _try_to_offload(self, host_region: Region, offload_region: Region):
        """
        Attempts to offload the region and prefetch it in backward pass.

        The regions are modified only for the duration of the evaluation and are
        restored afterwards.

        Args:
            host_region (Region): the region during which the prefetch would be issued;
                if it is the offloaded region itself, the upload is synchronous.
            offload_region (Region): the region to offload.

        Returns:
            AsynTrainingSimulator: the training simulator corresponding to the tentative strategy.
            tuple: the profit of the tentative strategy.
        """

        # record previous information
        orig_prefetch = host_region.bwd_prefetch_region
        orig_is_syn = offload_region.is_syn
        orig_need_offload = offload_region.need_offload

        try:
            if host_region == offload_region:
                offload_region.is_syn = True
            else:
                host_region.bwd_prefetch_region = offload_region
            offload_region.need_offload = True

            ts, profit = self._eval_one_choice()
        finally:
            # restore previous information, even if the simulation raised
            host_region.bwd_prefetch_region = orig_prefetch
            offload_region.is_syn = orig_is_syn
            offload_region.need_offload = orig_need_offload

        return ts, profit

    def _try_convert_to_syn_upload(self, host_region: Region, offload_region: Region):
        """
        Attempts to convert asynchronous prefetch into synchronous upload operations.

        The regions are modified only for the duration of the evaluation and are
        restored afterwards.

        Args:
            host_region (Region): the region currently hosting a prefetch.
            offload_region (Region): the offloaded region, currently asynchronous.

        Returns:
            AsynTrainingSimulator: the training simulator corresponding to the tentative strategy.
            tuple: the profit of the tentative strategy.
        """

        # record previous information
        orig_prefetch = host_region.bwd_prefetch_region
        orig_is_syn = offload_region.is_syn
        assert orig_prefetch is not None and not orig_is_syn

        try:
            host_region.bwd_prefetch_region = None
            offload_region.is_syn = True

            ts, profit = self._eval_one_choice()
        finally:
            # restore previous information, even if the simulation raised
            host_region.bwd_prefetch_region = orig_prefetch
            offload_region.is_syn = orig_is_syn

        return ts, profit

    def _repair_strategy(self):
        """
        Repair offload strategy.
        It attempts to convert asynchronous prefetch into synchronous upload operations and selects the best one.
        The repair process does not end until peak memory is reduced or there is no asynchronous prefetch operation.

        Returns:
            AsynTrainingSimulator: the simulator of the last applied conversion, or
                None if there was no asynchronous prefetch to convert.

        Raises:
            NotImplementedError: no conversion yields a positive profit.
        """
        print("repair strategy ......")

        peak_mem_saving = 0
        # Defined up front so the final return is valid even if the loop body never runs.
        best_ts = None
        while self.region_to_region_map and peak_mem_saving <= 0:
            max_profit = (0,)
            best_ts = None
            undo_host_region = None
            undo_offload_region = None

            for offload_region_id, host_region in self.region_to_region_map.items():
                offload_region = self.region_list[offload_region_id]
                assert host_region.bwd_prefetch_region == offload_region
                assert offload_region.need_offload
                assert not offload_region.is_syn

                ts, profit = self._try_convert_to_syn_upload(host_region, offload_region)

                if self._compare_profit(profit, max_profit):
                    undo_host_region = host_region
                    undo_offload_region = offload_region
                    max_profit = profit
                    best_ts = ts

            if best_ts is None:
                raise NotImplementedError("repair error!")

            # commit the best conversion found in this round
            assert not undo_offload_region.is_syn
            undo_offload_region.is_syn = True
            undo_host_region.bwd_prefetch_region = None

            # must be computed before _update_state replaces self.best_ts
            peak_mem_saving = self.best_ts.peak_mem - best_ts.peak_mem

            self._update_state(best_ts)
            self.region_to_region_map.pop(undo_offload_region.r_id)

        return best_ts

    def _eval_one_choice(self):
        """
        Evaluate the profit of a strategy choice.

        The extra cost is the increase of the simulated iteration end time over
        the current best strategy, floored at zero.

        Returns:
            AsynTrainingSimulator: the training simulator corresponding to the current strategy.
            tuple: contains memory saving and cost information of the current strategy.
        """

        ts = AsynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
        ts.execute()

        extra_comm_cost = max(ts.iter_end_time - self.best_ts.iter_end_time, 0)
        profit = self._compute_offload_profit(ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)

        return ts, profit
|
|
|
|
|
|
class SolverFactory:
    """
    Registry that maps a parameter offload policy name to its solver class.
    """

    # policy name -> solver class
    solvers: Dict[str, Type[Solver]] = {"syn": SynGreedySolver, "asyn": AsynGreedySolver}

    @staticmethod
    def create(solver_name: str) -> Type[Solver]:
        """
        Look up the solver class registered under ``solver_name``.

        Args:
            solver_name (str): the name of the offload policy.

        Returns:
            Type[Solver]: the solver class (not an instance).

        Raises:
            TypeError: ``solver_name`` is not a registered policy.
        """
        solver_cls = SolverFactory.solvers.get(solver_name)
        if solver_cls is None:
            raise TypeError(f"Unknown parameter offload policy {solver_name}")
        return solver_cls

    @staticmethod
    def get_solver_names():
        """
        Return the names of all registered offload policies as a tuple.
        """
        return tuple(SolverFactory.solvers)
|