2023-03-21 06:17:41 +00:00
|
|
|
from typing import Dict
|
2023-04-03 09:12:22 +00:00
|
|
|
|
2023-03-21 06:17:41 +00:00
|
|
|
import torch
|
|
|
|
import torch.fx
|
|
|
|
from torch.fx import GraphModule
|
|
|
|
from torch.utils._pytree import tree_map
|
|
|
|
|
|
|
|
from colossalai.fx import ColoTracer, is_compatible_with_meta
|
|
|
|
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
|
|
|
|
|
|
|
from .base_offload_module import BaseOffloadModule
|
2023-04-03 09:12:22 +00:00
|
|
|
from .region_manager import RegionManager
|
|
|
|
from .runtime import runtime_asyn_offload_apply_pass, runtime_syn_offload_apply_pass
|
|
|
|
from .util import GlobalRuntimeInfo, compute_act_peak_mem, compute_max_param_mem, compute_total_param_mem
|
|
|
|
|
2023-03-21 06:17:41 +00:00
|
|
|
|
|
|
|
def memory_optimize(model: torch.nn.Module,
                    inps: Dict[str, torch.Tensor],
                    memory_budget: float = -1.0,
                    solver_name: str = 'asyn'):
    """Trace ``model`` and wrap it in an offload-aware module.

    The model is traced with meta tensors (``ColoTracer``), annotated with ``MetaInfoProp``,
    partitioned into regions by ``RegionManager`` and rewritten by the runtime offload pass
    that matches ``solver_name``. The recompiled graph module is returned wrapped in a
    ``BaseOffloadModule``.

    Args:
        model: Module to optimize. It is moved to CPU and cast to fp16 *in place*
            (``nn.Module.cpu`` / ``nn.Module.half`` modify the module and return it), so the
            caller's object is changed as well.
        inps: Example inputs keyed by argument name. Tensor values are converted to ``meta``
            tensors for tracing; non-tensor values are passed through unchanged.
        memory_budget: Forwarded to ``RegionManager``. The meaning of the default ``-1.0`` is
            defined there (presumably "no explicit budget" -- TODO confirm in RegionManager).
        solver_name: ``'syn'`` or ``'asyn'``. Forwarded to ``RegionManager`` and used to pick
            the matching runtime offload pass.

    Returns:
        A ``BaseOffloadModule`` built from the rewritten ``GraphModule``, the region manager
        and a flag that is ``True`` only for the ``'syn'`` solver.

    Raises:
        TypeError: If ``solver_name`` is neither ``'syn'`` nor ``'asyn'``. This is checked
            before any work is done, so the model is left untouched in that case.
        RuntimeError: If ``is_compatible_with_meta()`` returns False.

    Side effects:
        Sets ``GlobalRuntimeInfo().region_list`` and prints a memory summary to stdout.
    """
    # Reject unknown solvers before mutating the model or running the expensive trace and
    # region building; previously this was only detected after all of that had happened.
    if solver_name not in ('syn', 'asyn'):
        raise TypeError(f"Unknown solver name {solver_name}!")
    # Explicit check rather than ``assert`` so it is not stripped under ``python -O``.
    if not is_compatible_with_meta():
        raise RuntimeError("memory_optimize requires meta-tensor support, "
                           "but is_compatible_with_meta() returned False")

    model = model.cpu().half()

    def _to_meta(x):
        # Only tensors can be moved to the meta device; other leaves pass through as-is.
        return x.to("meta") if isinstance(x, torch.Tensor) else x

    meta_args = tree_map(_to_meta, inps)
    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args=meta_args)
    gm = GraphModule(model, graph, model.__class__.__name__)

    # NOTE(review): the meta inputs are passed positionally here, so this relies on ``inps``
    # being ordered like the model's forward() parameters -- confirm with callers.
    interp = MetaInfoProp(gm)
    interp.propagate(*meta_args.values())

    region_manager = RegionManager(graph, solver_name=solver_name, memory_budget=memory_budget)
    region_manager._build_regions()
    region_list = region_manager.region_list
    # Publish the regions process-wide; the runtime offload code reads them from here.
    GlobalRuntimeInfo().region_list = region_list

    bytes_per_mb = 1024**2
    act_peak_mem = compute_act_peak_mem(region_list) / bytes_per_mb
    max_param_mem = compute_max_param_mem(region_list) / bytes_per_mb
    total_param_mem = compute_total_param_mem(region_list) / bytes_per_mb
    print(f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | "
          f"total_param_mem={total_param_mem:.3f} MB")

    if solver_name == 'syn':
        gm = runtime_syn_offload_apply_pass(gm, region_list)
    else:
        # solver_name == 'asyn' -- guaranteed by the validation at the top.
        gm = runtime_asyn_offload_apply_pass(gm, region_list)
    gm.recompile()

    # The third argument is True only for the synchronous ('syn') solver.
    optimized_model = BaseOffloadModule(gm, region_manager, solver_name == 'syn')
    return optimized_model
|