You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/colossalai/auto_parallel/offload/util.py

98 lines
2.7 KiB

from dataclasses import dataclass
from typing import List
import torch
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp
from .region import Region
@dataclass
class NodeInfo:
    """Small record tying a node id to two runtime memory figures.

    Every field has a zero default, so ``NodeInfo()`` is a valid empty
    record.
    """

    # Identifier of the node this record describes.
    node_id: int = 0
    # NOTE(review): presumably the runtime memory seen at this node during
    # the forward pass; the unit is not fixed by this class -- confirm.
    runtime_fwd_mem: float = 0.0
    # NOTE(review): presumably the backward-pass counterpart -- confirm.
    runtime_bwd_mem: float = 0.0
class NvDevicePower:
    """
    Computing performance of several NVIDIA GPUs, in TFLOPs.

    Constants are named ``<DEVICE>_<PRECISION>``; each line below sets the
    FP16 and FP32 figures of one device.
    """

    # RTX 3080
    RTX3080_FP16, RTX3080_FP32 = 70, 34.1
    # RTX 3090
    RTX3090_FP16, RTX3090_FP32 = 71, 35.7
    # V100
    V100_FP16, V100_FP32 = 31.4, 15.7
    # A100
    A100_FP16, A100_FP32 = 78, 19.5
class GlobalRuntimeInfo(metaclass=SingletonMeta):
    """Runtime state holder: two CUDA streams, two event maps and a region list.

    The class is created with ``SingletonMeta`` as its metaclass.
    """

    def __init__(self):
        # Two separate CUDA streams.
        # NOTE(review): the names suggest host-to-device and device-to-host
        # transfers respectively -- confirm against the callers.
        self.h2d_stream = torch.cuda.Stream()
        self.d2h_stream = torch.cuda.Stream()

        # Event lookup tables for the forward and backward passes; both
        # start out empty.
        self.fwd_prefetch_event_map = {}
        self.bwd_prefetch_event_map = {}

        # Starts out as an empty list.
        self.region_list = []
def compute_act_peak_mem(region_list: List[Region]) -> float:
    """Return the peak of a running activation-memory total over fwd + bwd.

    The running total is updated node by node: first across every region's
    nodes in the given order (forward), then across regions and their nodes
    in reverse order (backward). The largest value the total reaches is
    returned.

    Args:
        region_list: Regions in forward order. Each region exposes ``nodes``;
            each node exposes ``meta["bwd_mem_tmp"]``, ``meta["bwd_mem_out"]``,
            ``all_input_nodes`` and ``users``.

    Returns:
        The maximum running total observed; 0 if no node is visited.
    """
    act_peak_mem = 0
    runtime_mem = 0

    # forward: each node adds its temporary and output forward memory.
    for region in region_list:
        for node in region.nodes:
            runtime_mem = runtime_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node)
            act_peak_mem = max(runtime_mem, act_peak_mem)

    # backward
    # bwd_deps[n] starts at the number of input nodes of ``n`` and is
    # decremented each time one of those inputs is visited below.
    bwd_deps = {}
    for region in reversed(region_list):
        for node in reversed(region.nodes):
            # Drop the node's forward output, then add its backward
            # temporary and output memory before sampling the peak.
            runtime_mem -= calculate_fwd_out(node)
            runtime_mem = runtime_mem + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
            act_peak_mem = max(runtime_mem, act_peak_mem)

            # The backward temporary and the forward temporary are released
            # once the node has been processed.
            runtime_mem = runtime_mem - node.meta["bwd_mem_tmp"] - calculate_fwd_tmp(node)

            # free bwd_mem_out
            bwd_deps[node] = len(node.all_input_nodes)
            for user_node in node.users:
                # Only users already visited in this backward sweep are tracked.
                if user_node in bwd_deps:
                    bwd_deps[user_node] -= 1
                    if bwd_deps[user_node] <= 0:
                        # Every input of ``user_node`` has been visited, so
                        # its backward output no longer counts.
                        runtime_mem -= user_node.meta["bwd_mem_out"]

    return act_peak_mem
def compute_max_param_mem(region_list: List[Region]) -> float:
    """Return the largest ``param_size`` among the given regions.

    Args:
        region_list: Regions to inspect; each exposes ``param_size``.

    Returns:
        The maximum ``param_size``, or 0 when ``region_list`` is empty
        (instead of letting ``max`` raise ``ValueError`` on no input).
    """
    return max((region.param_size for region in region_list), default=0)
def compute_total_param_mem(region_list: List[Region]) -> float:
    """Sum ``param_size`` over regions whose ``r_id`` does not exceed ``shared_rid``.

    Args:
        region_list: Regions to inspect; each exposes ``param_size``,
            ``r_id`` and ``shared_rid``.

    Returns:
        The sum of the selected sizes (0 when nothing is selected).
    """
    # Regions with r_id > shared_rid are left out of the total.
    counted = (reg for reg in region_list if reg.r_id <= reg.shared_rid)
    return sum(reg.param_size for reg in counted)
def requires_upload_p_in_fwd(shared_reg: Region):
    """Evaluate ``r_id >= shared_rid or (r_id < shared_rid and need_offload)``.

    Args:
        shared_reg: Region exposing ``r_id``, ``shared_rid`` and
            ``need_offload``.

    Returns:
        The ``r_id >= shared_rid`` result when it is truthy; otherwise the
        value of ``r_id < shared_rid and need_offload``.
    """
    at_or_after_shared = shared_reg.r_id >= shared_reg.shared_rid
    if at_or_after_shared:
        return at_or_after_shared
    # Here r_id precedes shared_rid, so need_offload decides the outcome.
    return shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload
def requires_release_p_in_bwd(shared_reg: Region):
    """Evaluate ``r_id >= shared_rid or (r_id < shared_rid and need_offload)``.

    Args:
        shared_reg: Region exposing ``r_id``, ``shared_rid`` and
            ``need_offload``.

    Returns:
        A truthy value when ``r_id >= shared_rid``, or when
        ``r_id < shared_rid`` and ``need_offload`` is truthy; a falsy value
        otherwise.
    """
    not_before_shared = shared_reg.r_id >= shared_reg.shared_rid
    if not not_before_shared:
        # r_id did not reach shared_rid: the answer hinges on need_offload.
        offloaded_before_shared = shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload
        return offloaded_before_shared
    return not_before_shared
def requires_offload_g_in_bwd(region: Region):
    """Evaluate ``param_size and (r_id <= shared_rid)`` for ``region``.

    Args:
        region: Region exposing ``param_size``, ``r_id`` and ``shared_rid``.

    Returns:
        ``region.param_size`` itself when it is falsy (e.g. ``0``);
        otherwise the result of ``region.r_id <= region.shared_rid``.
    """
    size = region.param_size
    if not size:
        # A falsy size is handed back unchanged, as `and` would do.
        return size
    return region.r_id <= region.shared_rid