ColossalAI/colossalai/auto_parallel/offload/training_simulator.py

425 lines
18 KiB
Python
Raw Normal View History

import bisect
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Dict, List
from torch.fx.node import Node
from .region import Region
from .util import *
@dataclass
class ExecutionPeriod:
    """
    A simulated time interval ``[start_time, end_time]`` occupied by one piece of work
    (a computation, a parameter prefetch or a gradient offload) on its stream.
    """

    # time at which the work starts
    start_time: float = 0
    # time at which the work finishes
    end_time: float = 0
class TrainingSimulator(ABC):
    """
    The Training Simulator is used to simulate the training process.
    It records computation, communication, and runtime memory during forward and backward passes.

    Args:
        region_list (List[Region]): represents the linearized DNN computing graph.
        comp_power (float): the NVIDIA GPU FP16 computing power.
        link_to_bw (Dict[str, Dict[float, float]]): communication links and the corresponding bandwidth,
            keyed first by link name (e.g. "h2d", "d2h") and then by profiled transfer size.
    """

    def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None:
        self.region_list = region_list
        self.region_num = len(region_list)

        # memory counters; fractional multiples of parameter sizes may be added, hence float
        self.runtime_mem: float = 0
        self.peak_mem: float = 0
        self.total_mem_saving: float = 0

        # runtime memory recorded right after each node in the forward / backward pass
        self.fwd_node_mem: Dict[Node, float] = {}
        self.bwd_node_mem: Dict[Node, float] = {}

        # Node dependencies in backward pass
        self.bwd_node_deps: Dict[Node, int] = {}

        self.comp_power: float = comp_power
        self.link_to_bandwidth: Dict[str, Dict[float, float]] = link_to_bw

    @abstractmethod
    def execute(self):
        """Simulate one training iteration over ``region_list``."""
        raise NotImplementedError

    @abstractmethod
    def _eval_fwd_mem_per_region(self, region: Region):
        """Update the memory statistics for ``region`` in the forward pass."""
        raise NotImplementedError

    @abstractmethod
    def _eval_bwd_mem_per_region(self, region: Region):
        """Update the memory statistics for ``region`` in the backward pass."""
        raise NotImplementedError

    def _get_bandwidth(self, link: str, comm_volumn: float) -> float:
        """
        Get the data transfer bandwidth.

        The bandwidth profiled for the smallest size that is >= ``comm_volumn`` is used.
        Volumes larger than every profiled size use the bandwidth of the largest profiled size.

        Args:
            link (str): the data transfer link.
            comm_volumn (float): the amount of data transferred.

        Returns:
            float: the data transfer bandwidth.

        Raises:
            ValueError: if no bandwidth has been profiled at all, or none for ``link``.
            TypeError: if ``link`` is not a known data transfer link.
        """
        if not self.link_to_bandwidth:
            raise ValueError("No data transfer bandwidth has been profiled.")
        if link not in self.link_to_bandwidth:
            raise TypeError(f"Unknown data transfer link {link}")

        size_to_bw = self.link_to_bandwidth[link]
        size_list = sorted(size_to_bw.keys())
        if not size_list:
            raise ValueError(f"No bandwidth has been profiled for data transfer link {link}")

        # bisect_left returns len(size_list) for volumes above the largest profiled size;
        # clamp so such volumes reuse the largest profiled bandwidth instead of raising IndexError.
        d_idx = min(bisect.bisect_left(size_list, comm_volumn), len(size_list) - 1)
        return size_to_bw[size_list[d_idx]]

    def _get_communication_overhead(self, link: str, comm_volumn: float) -> float:
        """Return the transfer cost of ``comm_volumn`` over ``link``: volume divided by bandwidth."""
        return comm_volumn / self._get_bandwidth(link, comm_volumn)

    def _get_computing_overhead(self, flop: float) -> float:
        """Return the computation cost of ``flop`` operations: flop divided by ``comp_power``."""
        return flop / self.comp_power
class SynTrainingSimulator(TrainingSimulator):
    """
    Training simulator for synchronous offloading.

    Only memory is evaluated here: ``execute`` walks the regions forward, then backward,
    updating ``runtime_mem``, ``peak_mem``, ``total_mem_saving`` and the per-node memory
    records. No execution time is simulated.
    """

    def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None:
        super().__init__(region_list, comp_power, link_to_bw)

    def execute(self):
        """
        Simulate synchronous training process.

        Each region is evaluated once in forward order and once more in reverse order
        for the backward pass.
        """
        for fwd_region in self.region_list:
            self._eval_fwd_mem_per_region(fwd_region)
        for bwd_region in reversed(self.region_list):
            self._eval_bwd_mem_per_region(bwd_region)

    def _eval_fwd_mem_per_region(self, region: Region):
        """
        Evaluate the runtime and peak memory when the forward execution reaches the current region.

        Args:
            region (Region): the region executed in the forward pass.
        """
        param_owner = self.region_list[region.shared_rid]
        # parameters are counted only when their owning region must upload them in forward
        if requires_upload_p_in_fwd(param_owner):
            self.runtime_mem += region.param_size

        for fx_node in region.nodes:
            self.runtime_mem += calculate_fwd_tmp(fx_node) + calculate_fwd_out(fx_node)
            mem_now = self.runtime_mem
            self.fwd_node_mem[fx_node] = mem_now
            self.peak_mem = max(mem_now, self.peak_mem)
            self.total_mem_saving += fx_node.node_info.runtime_fwd_mem - mem_now

        # parameters of an offloaded region stop counting once the region has run
        if region.need_offload:
            self.runtime_mem -= region.param_size

    def _eval_bwd_mem_per_region(self, region: Region):
        """
        Evaluate the runtime and peak memory when the backward execution reaches the current region.

        Args:
            region (Region): the region executed in the backward pass.

        Raises:
            ValueError: if the simulated runtime memory drops below zero.
        """
        p_size = region.param_size
        accumulates_grad = region.r_id < region.shared_rid

        # offloaded parameters are uploaded again for the backward computation
        if region.need_offload:
            self.runtime_mem += p_size

        # gradient memory; shared parameters need extra room for gradient accumulation
        self.runtime_mem += 2.0 * p_size if accumulates_grad else p_size

        for fx_node in reversed(region.nodes):
            self.runtime_mem -= calculate_fwd_out(fx_node)
            self.runtime_mem += fx_node.meta["bwd_mem_tmp"] + fx_node.meta["bwd_mem_out"]
            mem_now = self.runtime_mem
            self.peak_mem = max(mem_now, self.peak_mem)
            # The saving may be negative due to parameter prefetch.
            self.total_mem_saving += fx_node.node_info.runtime_bwd_mem - mem_now
            self.bwd_node_mem[fx_node] = mem_now
            self.runtime_mem -= fx_node.meta["bwd_mem_tmp"] + calculate_fwd_tmp(fx_node)

            # release bwd_mem_out of already-visited users once their input counter hits zero
            self.bwd_node_deps[fx_node] = len(fx_node.all_input_nodes)
            for consumer in fx_node.users:
                if consumer not in self.bwd_node_deps:
                    continue
                self.bwd_node_deps[consumer] -= 1
                if self.bwd_node_deps[consumer] <= 0:
                    self.runtime_mem -= consumer.meta["bwd_mem_out"]

            if self.runtime_mem < 0:
                raise ValueError(
                    f"region id: {region.r_id}, node name: {fx_node.name}, "
                    f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---"
                    f"runtime memory computed less than 0, which is miscalculated!"
                )

        # release the parameters and offload the gradients of the region
        if region.r_id == region.shared_rid:
            self.runtime_mem -= 2.0 * p_size
        elif accumulates_grad:
            self.runtime_mem -= 3.0 * p_size
        elif self.region_list[region.shared_rid].need_offload:
            self.runtime_mem -= p_size
class AsynTrainingSimulator(TrainingSimulator):
    """
    Training simulator for asynchronous offloading.

    In addition to runtime / peak memory, it tracks execution periods on three streams:
    computation, parameter prefetch (h2d) and gradient offload (d2h), so that transfers
    can overlap with computation.
    """

    def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None:
        super().__init__(region_list, comp_power, link_to_bw)

        # simulated end time of one iteration; set at the end of execute()
        self.iter_end_time: float = 0
        # the last computation execution period
        self.last_comp: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0)
        # the last parameter prefetch execution period
        self.last_h2d: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0)
        # the last gradient offload execution period
        self.last_d2h: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0)
        # the forward computation execution period of the region
        self.fwd_reg_to_comp: OrderedDict[int, ExecutionPeriod] = OrderedDict()
        # the forward parameter prefetch execution period of the region
        self.fwd_reg_to_pref: OrderedDict[int, ExecutionPeriod] = OrderedDict()
        # the backward computation execution period of the region
        self.bwd_reg_to_comp: OrderedDict[int, ExecutionPeriod] = OrderedDict()
        # the backward parameter prefetch execution period of the region
        self.bwd_reg_to_pref: OrderedDict[int, ExecutionPeriod] = OrderedDict()
        # the gradient offload execution period of the region
        # which is divided into those that are waiting and those that have been released
        self.bwd_reg_to_offl_waiting: OrderedDict[int, ExecutionPeriod] = OrderedDict()
        self.bwd_reg_to_offl_freed: OrderedDict[int, ExecutionPeriod] = OrderedDict()
        # the region buffer, which records regions that are offloaded but not released
        # (region ids: in the forward pass an offloaded region waits here until the next
        # region runs; in the backward pass it holds regions whose gradient offload finished)
        self.reg_buffer_to_free: List[int] = []
        # node dependencies in backward pass
        # (re-created here; the base class __init__ already creates the same attribute)
        self.bwd_node_deps: Dict[Node, int] = {}
        # the region execution flow,
        # where fwd_reg_flow[i,j] denotes whether the parameters of j-th region are in the GPU
        # when the execution reaches the i-th region.
        self.fwd_reg_flow = torch.zeros((self.region_num, self.region_num)).bool()
        self.bwd_reg_flow = torch.zeros((self.region_num, self.region_num)).bool()

    def execute(self):
        """
        Simulate asynchronous training process.
        In forward pass, parameter prefetching is advanced by one region.
        In backward pass, parameter prefetching is executed at the specified location,
        and gradient offloading is urgent.

        Regions are visited in order for the forward pass and in reverse order for the
        backward pass. Afterwards every gradient offload still waiting is moved to the freed
        table, and ``iter_end_time`` is set to the later of the last computation end and the
        last gradient offload end.
        """
        for reg in self.region_list:
            # choose the first later region that has parameters and whose owning region
            # uploads parameters in forward as this region's forward prefetch target
            if reg.param_size and reg.r_id < self.region_num - 1:
                for nr in self.region_list[reg.r_id + 1 :]:
                    if nr.param_size and requires_upload_p_in_fwd(self.region_list[nr.shared_rid]):
                        reg.fwd_prefetch_region = nr
                        break
            self._eval_fwd_cost_per_region(reg)
            self._eval_fwd_mem_per_region(reg)

        for reg in self.region_list.__reversed__():
            self._eval_bwd_cost_per_region(reg)
            self._eval_bwd_mem_per_region(reg)

        # release remaining grads
        for reg_id, offl_exec in self.bwd_reg_to_offl_waiting.items():
            self.bwd_reg_to_offl_freed[reg_id] = offl_exec
            self.runtime_mem -= self.region_list[reg_id].param_size
        self.bwd_reg_to_offl_waiting.clear()

        self.iter_end_time = max(self.last_comp.end_time, self.last_d2h.end_time)

    def _insert_h2d_exec(self, region: Region, is_fwd: bool = True):
        """
        Insert parameter prefetch execution period of the current region to the end of the h2d stream

        The transfer starts once both the previous h2d transfer and the most recently
        inserted computation have finished. The period is stored under ``region.r_id`` in
        the forward or backward prefetch table depending on ``is_fwd``.
        """
        pref_start_time = max(self.last_h2d.end_time, self.last_comp.end_time)
        # NOTE(review): the duration is twice the h2d overhead of param_size; the reason for
        # the 2.0 factor is not visible in this file — confirm.
        pref_end_time = pref_start_time + 2.0 * self._get_communication_overhead("h2d", region.param_size)
        pref_ep = ExecutionPeriod(start_time=pref_start_time, end_time=pref_end_time)
        if is_fwd:
            self.fwd_reg_to_pref[region.r_id] = pref_ep
        else:
            self.bwd_reg_to_pref[region.r_id] = pref_ep
        self.last_h2d = pref_ep

    def _insert_comp_exec(self, region: Region, is_fwd: bool = True):
        """
        Insert computation execution period of the current region to the end of the computing stream

        The computation waits for the previous computation and for this region's own
        prefetch in the same pass (if any). Its duration is the sum of the per-node
        computing overheads of the ``fwd_flop`` / ``bwd_flop`` node metadata (0 when missing).
        """
        if is_fwd:
            reg_to_comp = self.fwd_reg_to_comp
            reg_to_pref = self.fwd_reg_to_pref
            flop_key = "fwd_flop"
        else:
            reg_to_comp = self.bwd_reg_to_comp
            reg_to_pref = self.bwd_reg_to_pref
            flop_key = "bwd_flop"
        comp_start_time = max(self.last_comp.end_time, reg_to_pref.get(region.r_id, ExecutionPeriod(0, 0)).end_time)
        comp_end_time = comp_start_time + sum(
            [self._get_computing_overhead(node.meta.get(flop_key, 0)) for node in region.nodes]
        )
        comp_ep = ExecutionPeriod(start_time=comp_start_time, end_time=comp_end_time)
        reg_to_comp[region.r_id] = comp_ep
        self.last_comp = comp_ep

    def _insert_d2h_exec(self, region: Region):
        """
        Insert gradient offload execution period of the current region to the end of the d2h stream

        The offload starts after both the previous d2h transfer and the latest computation
        end. It is recorded as waiting until a later computation starts after it finishes.
        """
        offl_start_time = max(self.last_d2h.end_time, self.last_comp.end_time)
        offl_end_time = offl_start_time + self._get_communication_overhead("d2h", region.param_size)
        offl_ep = ExecutionPeriod(start_time=offl_start_time, end_time=offl_end_time)
        self.bwd_reg_to_offl_waiting[region.r_id] = offl_ep
        self.last_d2h = offl_ep

    def _eval_fwd_cost_per_region(self, region: Region):
        """
        Evaluate computation and communication execution period of the region in forward pass.
        """
        # upload parameters of the first region
        if region.r_id == 0:
            self._insert_h2d_exec(region)

        # prefetch parameters of the next region
        fwd_prefetch_region = region.fwd_prefetch_region
        if fwd_prefetch_region and requires_upload_p_in_fwd(self.region_list[fwd_prefetch_region.shared_rid]):
            self._insert_h2d_exec(fwd_prefetch_region)

        # execute computation
        self._insert_comp_exec(region)

    def _eval_fwd_mem_per_region(self, region: Region):
        """
        Evaluate the runtime and peak memory when the forward execution reaches the current region.

        Row ``region.r_id`` of ``fwd_reg_flow`` starts as a copy of the previous row; regions
        buffered for release by the previous step are then marked absent and the buffer is
        cleared. If this region is offloaded, its parameter memory is released after its
        nodes and its id is buffered until the next region.
        """
        # upload parameters of the current region
        if region.r_id <= 0:
            self.runtime_mem += region.param_size
            self.fwd_reg_flow[region.r_id, region.r_id] = True
        else:
            self.fwd_reg_flow[region.r_id] = self.fwd_reg_flow[region.r_id - 1]
            self.fwd_reg_flow[region.r_id, self.reg_buffer_to_free] = False
            self.reg_buffer_to_free.clear()

        # prefetch parameters of the next region
        fwd_prefetch_region = region.fwd_prefetch_region
        if fwd_prefetch_region and requires_upload_p_in_fwd(self.region_list[fwd_prefetch_region.shared_rid]):
            self.runtime_mem += fwd_prefetch_region.param_size
            self.fwd_reg_flow[region.r_id, fwd_prefetch_region.r_id] = True

        for node in region.nodes:
            self.runtime_mem += calculate_fwd_tmp(node) + calculate_fwd_out(node)
            self.peak_mem = max(self.runtime_mem, self.peak_mem)
            self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem
            self.fwd_node_mem[node] = self.runtime_mem

        if region.need_offload:
            self.runtime_mem -= region.param_size

            # at most one offloaded region may be pending release at a time in forward
            assert len(self.reg_buffer_to_free) <= 1, f"{len(self.reg_buffer_to_free)}"
            self.reg_buffer_to_free.append(region.r_id)

    def _eval_bwd_cost_per_region(self, region: Region):
        """
        Evaluate computation and communication execution period of the region in backward pass.

        After scheduling the upload / prefetch / computation / gradient offload, the waiting
        offloads that finished before the current computation started are moved, in insertion
        order, to the freed table and to ``reg_buffer_to_free``. The scan stops at the first
        offload that has not finished.
        """
        # upload parameters of the current region
        if region.is_syn:
            assert region.need_offload
            self._insert_h2d_exec(region, is_fwd=False)

        # prefetch parameters of the region choiced, which is parallel to computation
        if region.bwd_prefetch_region is not None:
            self._insert_h2d_exec(region.bwd_prefetch_region, is_fwd=False)

        # execute computation
        self._insert_comp_exec(region, is_fwd=False)

        # offload gradient
        if requires_offload_g_in_bwd(region):
            self._insert_d2h_exec(region)

        # the buffer is expected to have been drained by the previous _eval_bwd_mem_per_region
        assert len(self.reg_buffer_to_free) == 0
        for reg_id, offl_exec in self.bwd_reg_to_offl_waiting.items():
            if offl_exec.end_time >= self.last_comp.start_time:
                break
            self.reg_buffer_to_free.append(reg_id)
            self.bwd_reg_to_offl_freed[reg_id] = offl_exec

        # removed after the scan to avoid mutating the dict while iterating it
        for reg_id in self.reg_buffer_to_free:
            self.bwd_reg_to_offl_waiting.pop(reg_id)

    def _eval_bwd_mem_per_region(self, region: Region):
        """
        Evaluate the runtime and peak memory when the backward execution reaches the current region.

        Raises:
            ValueError: if the simulated runtime memory drops below zero.
        """
        # row r_id starts from the row of the region processed just before in backward
        # (r_id + 1), or from the last forward row for the final region
        if region.r_id + 1 < self.region_num:
            self.bwd_reg_flow[region.r_id] = self.bwd_reg_flow[region.r_id + 1]
        else:
            self.bwd_reg_flow[region.r_id] = self.fwd_reg_flow[-1]
        self.bwd_reg_flow[region.r_id, self.reg_buffer_to_free] = False

        # free gradients in the buffer
        while len(self.reg_buffer_to_free):
            reg_id = self.reg_buffer_to_free.pop(0)
            self.runtime_mem -= self.region_list[reg_id].param_size

        # upload parameters of the current region
        if region.is_syn:
            self.runtime_mem += region.param_size
            self.bwd_reg_flow[region.r_id, region.r_id] = True

        # prefetch parameters of the region choiced
        bwd_prefetch_region = region.bwd_prefetch_region
        if bwd_prefetch_region:
            self.runtime_mem += bwd_prefetch_region.param_size
            self.bwd_reg_flow[region.r_id, bwd_prefetch_region.r_id] = True

        # add the gradient of the parameter
        if region.r_id < region.shared_rid:
            # gradient accumulation is required for shared parameters
            self.runtime_mem += 2.0 * region.param_size
        else:
            self.runtime_mem += region.param_size

        for node in region.nodes.__reversed__():
            self.runtime_mem -= calculate_fwd_out(node)
            self.runtime_mem += node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
            self.peak_mem = max(self.runtime_mem, self.peak_mem)

            # The memory savings of a node may be negative due to parameter prefetch.
            self.total_mem_saving += node.node_info.runtime_bwd_mem - self.runtime_mem

            self.bwd_node_mem[node] = self.runtime_mem

            self.runtime_mem -= node.meta["bwd_mem_tmp"] + calculate_fwd_tmp(node)

            # free bwd_mem_out
            # (a visited user's bwd_mem_out is released once its counter, initialised from
            # its input-node count, reaches zero)
            self.bwd_node_deps[node] = len(node.all_input_nodes)
            for user_node in node.users:
                if user_node in self.bwd_node_deps:
                    self.bwd_node_deps[user_node] -= 1
                    if self.bwd_node_deps[user_node] <= 0:
                        self.runtime_mem -= user_node.meta["bwd_mem_out"]

            if self.runtime_mem < 0:
                raise ValueError(
                    f"region id: {region.r_id}, node name: {node.name}, "
                    f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---"
                    f"runtime memory computed less than 0, which is miscalculated!"
                )

        # release parameters of the region
        if requires_release_p_in_bwd(self.region_list[region.shared_rid]):
            self.runtime_mem -= region.param_size