from dataclasses import dataclass
from typing import List

import torch

from colossalai.context.singleton_meta import SingletonMeta
from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp

from .region import Region


@dataclass
class NodeInfo:
    """Per-node profiling record: node id plus measured forward/backward runtime memory."""

    node_id: int = 0
    runtime_fwd_mem: float = 0
    runtime_bwd_mem: float = 0


class NvDevicePower:
    """
    NVIDIA GPU computing performance (TFLOPs).
    """

    RTX3080_FP16 = 70
    RTX3080_FP32 = 34.1

    RTX3090_FP16 = 71
    RTX3090_FP32 = 35.7

    V100_FP16 = 31.4
    V100_FP32 = 15.7

    A100_FP16 = 78
    A100_FP32 = 19.5


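# Note (added comment, an assumption about usage rather than code from this file):
# peak-throughput constants like these are typically used to turn a profiled FLOP
# count into a rough compute-time estimate, e.g.
#
#   est_time_s = node_flops / (NvDevicePower.A100_FP16 * 1e12)
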
class GlobalRuntimeInfo(metaclass=SingletonMeta):
    """Process-wide singleton for the offload runtime: dedicated CUDA streams for
    host-to-device (h2d) and device-to-host (d2h) copies, event maps used to
    synchronize forward/backward prefetches, and the list of parameter regions."""

    def __init__(self):
        self.h2d_stream = torch.cuda.Stream()
        self.d2h_stream = torch.cuda.Stream()
        self.fwd_prefetch_event_map = {}
        self.bwd_prefetch_event_map = {}
        self.region_list = []


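# Illustrative sketch (added comment, not code from this file) of how the fields above
# are commonly used: a prefetch is issued on h2d_stream and the compute stream waits on
# a recorded event before touching the data, e.g.
#
#   info = GlobalRuntimeInfo()
#   with torch.cuda.stream(info.h2d_stream):
#       gpu_param = cpu_param.to("cuda", non_blocking=True)
#   event = torch.cuda.Event()
#   event.record(info.h2d_stream)
#   ...
#   torch.cuda.current_stream().wait_event(event)  # before the op that needs gpu_param
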
def compute_act_peak_mem(region_list: List[Region]) -> float:
    """Simulate a forward and backward pass over all regions and return the peak
    activation memory observed."""
    act_peak_mem = 0
    runtime_mem = 0

    # forward: accumulate each node's temporary and output activation memory
    for region in region_list:
        for node in region.nodes:
            runtime_mem = runtime_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node)
            act_peak_mem = max(runtime_mem, act_peak_mem)

    # backward: walk regions and nodes in reverse execution order
    bwd_deps = {}
    for region in reversed(region_list):
        for node in reversed(region.nodes):
            runtime_mem -= calculate_fwd_out(node)
            runtime_mem = runtime_mem + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]

            act_peak_mem = max(runtime_mem, act_peak_mem)

            runtime_mem = runtime_mem - node.meta["bwd_mem_tmp"] - calculate_fwd_tmp(node)

            # free bwd_mem_out: release a user's gradient output once all of its inputs
            # have been visited in the backward walk
            bwd_deps[node] = len(node.all_input_nodes)
            for user_node in node.users:
                if user_node in bwd_deps:
                    bwd_deps[user_node] -= 1
                    if bwd_deps[user_node] <= 0:
                        runtime_mem -= user_node.meta["bwd_mem_out"]

    return act_peak_mem


def compute_max_param_mem(region_list: List[Region]) -> float:
    """Return the parameter memory of the largest single region."""
    return max(region.param_size for region in region_list)


def compute_total_param_mem(region_list: List[Region]) -> float:
    """Return the total parameter memory, counting shared regions only once."""
    return sum(region.param_size for region in region_list if region.r_id <= region.shared_rid)


def requires_upload_p_in_fwd(shared_reg: Region):
    return (shared_reg.r_id >= shared_reg.shared_rid) or (
        shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload
    )


def requires_release_p_in_bwd(shared_reg: Region):
    return (shared_reg.r_id >= shared_reg.shared_rid) or (
        shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload
    )


def requires_offload_g_in_bwd(region: Region):
    return region.param_size and (region.r_id <= region.shared_rid)
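
# Usage sketch (added comment, illustrative only): assuming region_list is the
# List[Region] produced by the offload solver, the helpers above give rough memory
# figures for planning, e.g.
#
#   act_peak = compute_act_peak_mem(region_list)         # peak activation memory
#   max_param = compute_max_param_mem(region_list)       # largest single region's parameters
#   total_param = compute_total_param_mem(region_list)   # parameters, shared regions counted once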