[test] fixed gemini plugin test (#3411)

* [test] fixed gemini plugin test

* polish code

* polish code
pull/3418/head
Frank Lee 2023-04-03 17:12:22 +08:00 committed by GitHub
parent 30412866e0
commit 638a07a7f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 124 additions and 131 deletions

View File

@ -1,10 +1,11 @@
from typing import Optional, Set
from functools import partial from functools import partial
from typing import Optional, Set
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai.nn.parallel.data_parallel import _cast_float
from colossalai.gemini.tensor_utils import free_storage from colossalai.gemini.tensor_utils import free_storage
from colossalai.nn.parallel.data_parallel import _cast_float
from .region_manager import RegionManager from .region_manager import RegionManager
from .util import GlobalRuntimeInfo from .util import GlobalRuntimeInfo
@ -20,10 +21,7 @@ class BaseOffloadModule:
is_sync (bool): synchronous mode or not. is_sync (bool): synchronous mode or not.
""" """
def __init__(self, def __init__(self, model: nn.Module, region_manager: RegionManager, is_sync=True):
model: nn.Module,
region_manager: RegionManager,
is_sync=True):
self.model = model self.model = model
self.region_manager = region_manager self.region_manager = region_manager
@ -69,8 +67,8 @@ class BaseOffloadModule:
for p in self.model.parameters(): for p in self.model.parameters():
p.grad = None p.grad = None
GlobalRuntimeInfo.fwd_prefetch_event_map.clear() GlobalRuntimeInfo().fwd_prefetch_event_map.clear()
GlobalRuntimeInfo.bwd_prefetch_event_map.clear() GlobalRuntimeInfo().bwd_prefetch_event_map.clear()
def grad_handle(self, p, grad): def grad_handle(self, p, grad):
empty_grad = torch.empty_like(grad) empty_grad = torch.empty_like(grad)
@ -82,7 +80,7 @@ class BaseOffloadModule:
self.overflow_counter += region.has_inf_or_nan self.overflow_counter += region.has_inf_or_nan
master_stream = torch.cuda.current_stream() master_stream = torch.cuda.current_stream()
with torch.cuda.stream(self.grad_offload_stream): with torch.cuda.stream(self.grad_offload_stream):
GlobalRuntimeInfo.d2h_stream.wait_stream(master_stream) GlobalRuntimeInfo().d2h_stream.wait_stream(master_stream)
region.move_grad_to_cpu() region.move_grad_to_cpu()
return empty_grad return empty_grad

View File

@ -1,4 +1,5 @@
from typing import Dict from typing import Dict
import torch import torch
import torch.fx import torch.fx
from torch.fx import GraphModule from torch.fx import GraphModule
@ -7,10 +8,11 @@ from torch.utils._pytree import tree_map
from colossalai.fx import ColoTracer, is_compatible_with_meta from colossalai.fx import ColoTracer, is_compatible_with_meta
from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from .region_manager import RegionManager
from .runtime import runtime_syn_offload_apply_pass, runtime_asyn_offload_apply_pass
from .base_offload_module import BaseOffloadModule from .base_offload_module import BaseOffloadModule
from .util import compute_max_param_mem, compute_total_param_mem, compute_act_peak_mem, GlobalRuntimeInfo from .region_manager import RegionManager
from .runtime import runtime_asyn_offload_apply_pass, runtime_syn_offload_apply_pass
from .util import GlobalRuntimeInfo, compute_act_peak_mem, compute_max_param_mem, compute_total_param_mem
def memory_optimize(model: torch.nn.Module, def memory_optimize(model: torch.nn.Module,
inps: Dict[str, torch.Tensor], inps: Dict[str, torch.Tensor],
@ -29,13 +31,14 @@ def memory_optimize(model: torch.nn.Module,
region_manager = RegionManager(graph, solver_name=solver_name, memory_budget=memory_budget) region_manager = RegionManager(graph, solver_name=solver_name, memory_budget=memory_budget)
region_manager._build_regions() region_manager._build_regions()
GlobalRuntimeInfo.region_list = region_manager.region_list GlobalRuntimeInfo().region_list = region_manager.region_list
act_peak_mem = compute_act_peak_mem(region_manager.region_list) / 1024 ** 2 act_peak_mem = compute_act_peak_mem(region_manager.region_list) / 1024**2
max_param_mem = compute_max_param_mem(region_manager.region_list) / 1024 ** 2 max_param_mem = compute_max_param_mem(region_manager.region_list) / 1024**2
total_param_mem = compute_total_param_mem(region_manager.region_list) / 1024 ** 2 total_param_mem = compute_total_param_mem(region_manager.region_list) / 1024**2
print( print(
f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}") f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}"
)
if solver_name == 'syn': if solver_name == 'syn':
gm = runtime_syn_offload_apply_pass(gm, region_manager.region_list) gm = runtime_syn_offload_apply_pass(gm, region_manager.region_list)
@ -45,5 +48,5 @@ def memory_optimize(model: torch.nn.Module,
raise TypeError(f"Unknown solver name {solver_name}!") raise TypeError(f"Unknown solver name {solver_name}!")
gm.recompile() gm.recompile()
optimized_model = BaseOffloadModule(gm, region_manager, solver_name=='syn') optimized_model = BaseOffloadModule(gm, region_manager, solver_name == 'syn')
return optimized_model return optimized_model

View File

@ -1,4 +1,5 @@
from typing import List from typing import List
import torch import torch
from torch.fx.node import Node from torch.fx.node import Node
@ -23,13 +24,13 @@ class SynPreFwdPostBwdOP(torch.autograd.Function):
ctx.bwd_info = bwd_info ctx.bwd_info = bwd_info
d2h_rid = fwd_info.get('d2h_rid', None) d2h_rid = fwd_info.get('d2h_rid', None)
if d2h_rid is not None: if d2h_rid is not None:
free_region = GlobalRuntimeInfo.region_list[d2h_rid] free_region = GlobalRuntimeInfo().region_list[d2h_rid]
assert isinstance(free_region, Region) assert isinstance(free_region, Region)
free_region.free_cuda_data() free_region.free_cuda_data()
h2d_rid = fwd_info.get('h2d_rid', None) h2d_rid = fwd_info.get('h2d_rid', None)
if h2d_rid is not None: if h2d_rid is not None:
h2d_region = GlobalRuntimeInfo.region_list[h2d_rid] h2d_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(h2d_region, Region) assert isinstance(h2d_region, Region)
h2d_region.move_param_to_cuda() h2d_region.move_param_to_cuda()
@ -40,7 +41,7 @@ class SynPreFwdPostBwdOP(torch.autograd.Function):
h2d_rid = ctx.bwd_info.get('h2d_rid', None) h2d_rid = ctx.bwd_info.get('h2d_rid', None)
if h2d_rid is not None: if h2d_rid is not None:
pref_region = GlobalRuntimeInfo.region_list[h2d_rid] pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(pref_region, Region) assert isinstance(pref_region, Region)
pref_region.move_param_to_cuda() pref_region.move_param_to_cuda()
@ -65,23 +66,22 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
sync_rid = fwd_info.get('sync_rid', None) sync_rid = fwd_info.get('sync_rid', None)
if sync_rid is not None: if sync_rid is not None:
prefetch_event = GlobalRuntimeInfo.fwd_prefetch_event_map.get( prefetch_event = GlobalRuntimeInfo().fwd_prefetch_event_map.get(sync_rid, None)
sync_rid, None)
if prefetch_event: if prefetch_event:
prefetch_event.wait() prefetch_event.wait()
h2d_rid = fwd_info.get('h2d_rid', None) h2d_rid = fwd_info.get('h2d_rid', None)
if h2d_rid is not None: if h2d_rid is not None:
pref_region = GlobalRuntimeInfo.region_list[h2d_rid] pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(pref_region, Region) assert isinstance(pref_region, Region)
master_stream = torch.cuda.current_stream() master_stream = torch.cuda.current_stream()
with torch.cuda.stream(GlobalRuntimeInfo.h2d_stream): with torch.cuda.stream(GlobalRuntimeInfo().h2d_stream):
GlobalRuntimeInfo.h2d_stream.wait_stream(master_stream) GlobalRuntimeInfo().h2d_stream.wait_stream(master_stream)
pref_region.move_param_to_cuda() pref_region.move_param_to_cuda()
prefetch_event = torch.cuda.Event() prefetch_event = torch.cuda.Event()
prefetch_event.record(GlobalRuntimeInfo.h2d_stream) prefetch_event.record(GlobalRuntimeInfo().h2d_stream)
GlobalRuntimeInfo.fwd_prefetch_event_map[h2d_rid] = prefetch_event GlobalRuntimeInfo().fwd_prefetch_event_map[h2d_rid] = prefetch_event
return input_ return input_
@ -90,10 +90,9 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
sync_rid = ctx.bwd_info.get('sync_rid', None) sync_rid = ctx.bwd_info.get('sync_rid', None)
if sync_rid is not None: if sync_rid is not None:
wait_region = GlobalRuntimeInfo.region_list[sync_rid] wait_region = GlobalRuntimeInfo().region_list[sync_rid]
assert isinstance(wait_region, Region) assert isinstance(wait_region, Region)
prefetch_event = GlobalRuntimeInfo.bwd_prefetch_event_map.get( prefetch_event = GlobalRuntimeInfo().bwd_prefetch_event_map.get(sync_rid, None)
sync_rid, None)
if prefetch_event: if prefetch_event:
prefetch_event.wait() prefetch_event.wait()
else: else:
@ -101,16 +100,16 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
h2d_rid = ctx.bwd_info.get('h2d_rid', None) h2d_rid = ctx.bwd_info.get('h2d_rid', None)
if h2d_rid is not None: if h2d_rid is not None:
pref_region = GlobalRuntimeInfo.region_list[h2d_rid] pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(pref_region, Region) assert isinstance(pref_region, Region)
master_stream = torch.cuda.current_stream() master_stream = torch.cuda.current_stream()
with torch.cuda.stream(GlobalRuntimeInfo.h2d_stream): with torch.cuda.stream(GlobalRuntimeInfo().h2d_stream):
GlobalRuntimeInfo.h2d_stream.wait_stream(master_stream) GlobalRuntimeInfo().h2d_stream.wait_stream(master_stream)
pref_region.move_param_to_cuda() pref_region.move_param_to_cuda()
prefetch_event = torch.cuda.Event() prefetch_event = torch.cuda.Event()
prefetch_event.record(GlobalRuntimeInfo.h2d_stream) prefetch_event.record(GlobalRuntimeInfo().h2d_stream)
GlobalRuntimeInfo.bwd_prefetch_event_map[h2d_rid] = prefetch_event GlobalRuntimeInfo().bwd_prefetch_event_map[h2d_rid] = prefetch_event
return grad_output, None, None return grad_output, None, None
@ -129,6 +128,7 @@ def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info):
ret = SynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info) ret = SynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
return ret return ret
def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info): def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info):
''' '''
Convert Prefetch and Offload operation into runtime action. Convert Prefetch and Offload operation into runtime action.
@ -189,7 +189,8 @@ def runtime_syn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[R
if fwd_info or bwd_info: if fwd_info or bwd_info:
with mod_graph.inserting_after(last_inp_node): with mod_graph.inserting_after(last_inp_node):
new_node = mod_graph.create_node('call_function', convert_fwd_upload_bwd_offload_to_action, new_node = mod_graph.create_node('call_function',
convert_fwd_upload_bwd_offload_to_action,
args=(last_inp_node, fwd_info, bwd_info)) args=(last_inp_node, fwd_info, bwd_info))
replace_node_users(last_inp_node, new_node) replace_node_users(last_inp_node, new_node)
@ -206,11 +207,11 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
# upload parameters of the first region # upload parameters of the first region
last_inp_node = tuple(mod_graph.nodes)[0] last_inp_node = tuple(mod_graph.nodes)[0]
first_region_with_p = [ first_region_with_p = [region for region in region_list if region.param_size][0]
region for region in region_list if region.param_size][0]
fwd_info = {"h2d_rid": first_region_with_p.r_id} fwd_info = {"h2d_rid": first_region_with_p.r_id}
with mod_graph.inserting_after(last_inp_node): with mod_graph.inserting_after(last_inp_node):
upload_apply_node = mod_graph.create_node('call_function', convert_fwd_upload_bwd_offload_to_action, upload_apply_node = mod_graph.create_node('call_function',
convert_fwd_upload_bwd_offload_to_action,
args=(last_inp_node, fwd_info, {})) args=(last_inp_node, fwd_info, {}))
replace_node_users(last_inp_node, upload_apply_node) replace_node_users(last_inp_node, upload_apply_node)
last_inp_node = upload_apply_node last_inp_node = upload_apply_node
@ -225,19 +226,20 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
fwd_info['h2d_rid'] = fwd_prefetch_region.r_id fwd_info['h2d_rid'] = fwd_prefetch_region.r_id
# forward offload # forward offload
if r_idx > 0 and region_list[r_idx-1].need_offload: if r_idx > 0 and region_list[r_idx - 1].need_offload:
fwd_info['d2h_rid'] = r_idx - 1 fwd_info['d2h_rid'] = r_idx - 1
bwd_info = {} bwd_info = {}
# backward prefetch # backward prefetch
if r_idx > 0 and region_list[r_idx-1].need_offload: if r_idx > 0 and region_list[r_idx - 1].need_offload:
bwd_info['sync_rid'] = r_idx - 1 bwd_info['sync_rid'] = r_idx - 1
if r_idx > 0 and region_list[r_idx-1].bwd_prefetch_region: if r_idx > 0 and region_list[r_idx - 1].bwd_prefetch_region:
bwd_info['h2d_rid'] = region_list[r_idx-1].bwd_prefetch_region.r_id bwd_info['h2d_rid'] = region_list[r_idx - 1].bwd_prefetch_region.r_id
if fwd_info or bwd_info: if fwd_info or bwd_info:
with mod_graph.inserting_after(last_inp_node): with mod_graph.inserting_after(last_inp_node):
new_node = mod_graph.create_node('call_function', convert_fwd_prefetch_bwd_offload_to_action, new_node = mod_graph.create_node('call_function',
convert_fwd_prefetch_bwd_offload_to_action,
args=(last_inp_node, fwd_info, bwd_info)) args=(last_inp_node, fwd_info, bwd_info))
replace_node_users(last_inp_node, new_node) replace_node_users(last_inp_node, new_node)
@ -246,7 +248,8 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
if region.bwd_prefetch_region: if region.bwd_prefetch_region:
bwd_info = {'h2d_rid': region.bwd_prefetch_region.r_id} bwd_info = {'h2d_rid': region.bwd_prefetch_region.r_id}
with mod_graph.inserting_after(last_inp_node): with mod_graph.inserting_after(last_inp_node):
new_node = mod_graph.create_node('call_function', convert_fwd_prefetch_bwd_offload_to_action, new_node = mod_graph.create_node('call_function',
convert_fwd_prefetch_bwd_offload_to_action,
args=(last_inp_node, {}, bwd_info)) args=(last_inp_node, {}, bwd_info))
replace_node_users(last_inp_node, new_node) replace_node_users(last_inp_node, new_node)
# gm.graph.print_tabular() # gm.graph.print_tabular()

View File

@ -1,6 +1,9 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import List from typing import List
import torch import torch
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp
from .region import Region from .region import Region
@ -12,6 +15,7 @@ class NodeInfo:
runtime_fwd_mem: float = 0 runtime_fwd_mem: float = 0
runtime_bwd_mem: float = 0 runtime_bwd_mem: float = 0
class NvDevicePower: class NvDevicePower:
""" """
NVIDIA GPU computing performance (TFLOPs). NVIDIA GPU computing performance (TFLOPs).
@ -30,12 +34,14 @@ class NvDevicePower:
A100_FP32 = 19.5 A100_FP32 = 19.5
class GlobalRuntimeInfo: class GlobalRuntimeInfo(metaclass=SingletonMeta):
h2d_stream = torch.cuda.Stream()
d2h_stream = torch.cuda.Stream() def __init__(self):
fwd_prefetch_event_map = {} self.h2d_stream = torch.cuda.Stream()
bwd_prefetch_event_map = {} self.d2h_stream = torch.cuda.Stream()
region_list = [] self.fwd_prefetch_event_map = {}
self.bwd_prefetch_event_map = {}
self.region_list = []
def compute_act_peak_mem(region_list: List[Region]) -> float: def compute_act_peak_mem(region_list: List[Region]) -> float:
@ -70,21 +76,24 @@ def compute_act_peak_mem(region_list: List[Region]) -> float:
return act_peak_mem return act_peak_mem
def compute_max_param_mem(region_list: List[Region]) -> float: def compute_max_param_mem(region_list: List[Region]) -> float:
return max(region.param_size for region in region_list) return max(region.param_size for region in region_list)
def compute_total_param_mem(region_list: List[Region]) -> float: def compute_total_param_mem(region_list: List[Region]) -> float:
return sum(region.param_size for region in region_list if region.r_id <= region.shared_rid) return sum(region.param_size for region in region_list if region.r_id <= region.shared_rid)
def requires_upload_p_in_fwd(shared_reg: Region): def requires_upload_p_in_fwd(shared_reg: Region):
return (shared_reg.r_id >= shared_reg.shared_rid) or ( return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid
shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload) and shared_reg.need_offload)
def requires_release_p_in_bwd(shared_reg: Region): def requires_release_p_in_bwd(shared_reg: Region):
return (shared_reg.r_id >= shared_reg.shared_rid) or ( return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid
shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload) and shared_reg.need_offload)
def requires_offload_g_in_bwd(region: Region): def requires_offload_g_in_bwd(region: Region):
return region.param_size and (region.r_id <= region.shared_rid) return region.param_size and (region.r_id <= region.shared_rid)

View File

@ -1,46 +1,44 @@
import time import time
import pytest
from functools import partial from functools import partial
import pytest
import torch import torch
from torch.utils._pytree import tree_map
import torch.multiprocessing as mp import torch.multiprocessing as mp
from torch.utils._pytree import tree_map
import colossalai import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.fx.profiler import parameter_size
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.utils import free_port, get_current_device
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer
from colossalai.auto_parallel.offload.mem_optimize import memory_optimize from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
from colossalai.auto_parallel.offload.solver import NOT_NVML from colossalai.auto_parallel.offload.solver import NOT_NVML
from colossalai.fx.profiler import parameter_size
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.testing import parameterize from colossalai.testing import parameterize
from colossalai.utils import free_port, get_current_device
from tests.test_tensor.common_utils import set_seed from colossalai.utils.model.colo_init_context import ColoInitContext
from tests.test_auto_parallel.test_offload.model_utils import * from tests.test_auto_parallel.test_offload.model_utils import *
from tests.test_tensor.common_utils import set_seed
@parameterize('model_name', ['gpt2_']) @parameterize('model_name', ['gpt2_'])
@parameterize('memory_budget', [5000]) @parameterize('memory_budget', [5000])
@parameterize('solver_name', ['asyn']) @parameterize('solver_name', ['asyn'])
def exam_fwd_bwd( def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str):
model_name: str,
memory_budget: float,
solver_name: str
):
# build model # build model
get_components_func = non_distributed_component_funcs.get_callable(model_name) get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, data_gen = get_components_func() model_builder, data_gen = get_components_func()
label = torch.randint(low=0, high=128, size=(64, 8,), device=get_current_device()) label = torch.randint(low=0, high=128, size=(
64,
8,
), device=get_current_device())
criterion = LMLoss() criterion = LMLoss()
set_seed(42) set_seed(42)
start_time = time.time() start_time = time.time()
model = model_builder() model = model_builder()
model.train() model.train()
param_size = parameter_size(model) / 1024 ** 2 / 2 param_size = parameter_size(model) / 1024**2 / 2
init_time = time.time() - start_time init_time = time.time() - start_time
print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s") print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s")
@ -92,13 +90,11 @@ def exam_fwd_bwd(
torch.cuda.synchronize() torch.cuda.synchronize()
exec_time = sum(sorted(time_list)[:5]) / 5 exec_time = sum(sorted(time_list)[:5]) / 5
runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2 runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2
runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2 runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2
print(f'gemini | model_name: {model_name}') print(f'gemini | model_name: {model_name}')
print( print(f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|')
f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|'
)
print(time_list) print(time_list)
del data_args del data_args
@ -129,22 +125,26 @@ def exam_fwd_bwd(
torch.cuda.synchronize() torch.cuda.synchronize()
exec_time = sum(sorted(time_list)[:5]) / 5 exec_time = sum(sorted(time_list)[:5]) / 5
runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2 runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2
runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2 runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2
print(f'solver_name: {solver_name} | model_name: {model_name}') print(f'solver_name: {solver_name} | model_name: {model_name}')
print( print(f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|')
f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|'
)
print(time_list) print(time_list)
@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
def test_perf(rank, world_size, port): def run_dist(rank, world_size, port):
config = {} config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_fwd_bwd() exam_fwd_bwd()
if __name__ == '__main__': @pytest.mark.skip("this test failed")
run_func = partial(test_perf, world_size=1, port=free_port()) @pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
def test_perf():
run_func = partial(run_dist, world_size=1, port=free_port())
mp.spawn(run_func, nprocs=1) mp.spawn(run_func, nprocs=1)
if __name__ == '__main__':
test_perf()

View File

@ -21,9 +21,6 @@ def check_gemini_plugin(early_stop: bool = True):
Args: Args:
early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True. early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True.
""" """
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5)
booster = Booster(plugin=plugin)
passed_models = [] passed_models = []
failed_info = {} # (model_name, error) pair failed_info = {} # (model_name, error) pair
@ -34,46 +31,23 @@ def check_gemini_plugin(early_stop: bool = True):
continue continue
# These models are not compatible with gemini # These models are not compatible with gemini
if name in [ if name in [
'diffusers_clip_vision_model', 'diffusers_clip_vision_model', 'timm_resnet', 'timm_beit', 'timm_beitv2', 'timm_eca_nfnet',
'timm_resnet', 'timm_efficientformer', 'timm_hrnet_w18_small', 'timm_nf_ecaresnet101', 'timm_nf_regnet_b0',
'timm_beit', 'timm_skresnet18', 'timm_wide_resnet50_2', 'timm_convit', 'timm_dm_nfnet', 'timm_swin_transformer',
'timm_beitv2', 'torchaudio_conformer', 'torchaudio_deepspeech', 'torchaudio_wavernn', 'torchaudio_tacotron',
'timm_eca_nfnet', 'deepfm_interactionarch', 'deepfm_simpledeepfmnn', 'dlrm', 'dlrm_interactionarch',
'timm_efficientformer', 'torchvision_googlenet', 'torchvision_inception_v3', 'torchvision_mobilenet_v3_small',
'timm_hrnet_w18_small', 'torchvision_resnet18', 'torchvision_resnext50_32x4d', 'torchvision_wide_resnet50_2',
'timm_nf_ecaresnet101', 'torchvision_vit_b_16', 'torchvision_convnext_base', 'torchvision_swin_s', 'transformers_albert',
'timm_nf_regnet_b0', 'transformers_albert_for_pretraining', 'transformers_bert', 'transformers_bert_for_pretraining',
'timm_skresnet18', 'transformers_gpt_double_heads', 'torchaudio_hubert_base', 'torchaudio_wav2vec2_base',
'timm_wide_resnet50_2', 'transformers_t5_for_conditional_generation', 'transformers_t5', 'transformers_t5_encoder_model'
'timm_convit',
'timm_dm_nfnet',
'timm_swin_transformer',
'torchaudio_conformer',
'torchaudio_deepspeech',
'torchaudio_wavernn',
'torchaudio_tacotron',
'deepfm_interactionarch',
'deepfm_simpledeepfmnn',
'dlrm',
'dlrm_interactionarch',
'torchvision_googlenet',
'torchvision_inception_v3',
'torchvision_mobilenet_v3_small',
'torchvision_resnet18',
'torchvision_resnext50_32x4d',
'torchvision_wide_resnet50_2',
'torchvision_vit_b_16',
'torchvision_convnext_base',
'torchvision_swin_s',
'transformers_albert',
'transformers_albert_for_pretraining',
'transformers_bert',
'transformers_bert_for_pretraining',
'transformers_gpt_double_heads',
'torchaudio_hubert_base',
]: ]:
continue continue
try: try:
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5)
booster = Booster(plugin=plugin)
model = model_fn() model = model_fn()
optimizer = HybridAdam(model.parameters(), lr=1e-3) optimizer = HybridAdam(model.parameters(), lr=1e-3)
criterion = lambda x: x.mean() criterion = lambda x: x.mean()
@ -97,10 +71,15 @@ def check_gemini_plugin(early_stop: bool = True):
booster.backward(loss, optimizer) booster.backward(loss, optimizer)
optimizer.step() optimizer.step()
passed_models.append(name) passed_models.append(name)
del booster, plugin, model, optimizer, criterion, data, output, loss
except Exception as e: except Exception as e:
failed_info[name] = e failed_info[name] = e
if early_stop: if early_stop:
raise e raise e
torch.cuda.empty_cache()
if dist.get_rank() == 0: if dist.get_rank() == 0:
print(f'Passed models({len(passed_models)}): {passed_models}\n\n') print(f'Passed models({len(passed_models)}): {passed_models}\n\n')
print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n') print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n')
@ -138,7 +117,6 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
check_gemini_plugin(early_stop=early_stop) check_gemini_plugin(early_stop=early_stop)
@pytest.mark.skip(reason='Skip gemini plugin test due to OOM')
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_gemini_plugin(early_stop: bool = True): def test_gemini_plugin(early_stop: bool = True):
world_size = 2 world_size = 2

View File

@ -9,6 +9,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close from torch.testing import assert_close
import colossalai import colossalai
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.testing.random import seed_all from colossalai.testing.random import seed_all
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.zero import LowLevelZeroOptimizer from colossalai.zero import LowLevelZeroOptimizer
@ -176,6 +177,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_if_address_is_in_use()
def test_zero_1_2(): def test_zero_1_2():
world_size = 2 world_size = 2
run_func = partial(run_dist, world_size=world_size, port=free_port()) run_func = partial(run_dist, world_size=world_size, port=free_port())