mirror of https://github.com/hpcaitech/ColossalAI

commit d95cfe2622 ("basic memory"), parent c35718e8db
@@ -6,6 +6,7 @@ from typing import List, Callable, Any, Tuple, Dict, Iterable
try:
    from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
    from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin
    from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, parameter_size, activation_size
    CODEGEN_AVAILABLE = True
except:
    from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args, _CustomBuiltin
@@ -18,6 +19,82 @@ else:
__all__ = ['python_code_with_activation_checkpoint']


def _get_meta_node_size(x):
    x = x.meta['tensor_meta']
    x = x.numel * torch.tensor([], dtype=x.dtype).element_size()
    return x
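The helper above sizes a placeholder from its `tensor_meta` record: `numel` times the per-element width, which it reads off a zero-element tensor of the same dtype so no real storage is allocated. A standalone sketch of the same computation (the `tensor_nbytes` helper and the example shape are illustrative, not part of the patch):

```python
import torch

def tensor_nbytes(shape, dtype):
    # hypothetical helper: numel * element_size, with the element width read
    # from a zero-element tensor so nothing is actually allocated
    numel = 1
    for dim in shape:
        numel *= dim
    return numel * torch.tensor([], dtype=dtype).element_size()

print(tensor_nbytes((1, 32, 32, 128), torch.float32))    # 524288 bytes (0.5 MB)
```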

def _get_output_node_size(n):
    fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')}
    return activation_size(fwd_out)


def _get_delete_node_size(user, user_to_last_uses):
    if user.op in ('placeholder', 'output'):
        return 0
    nodes_to_delete = user_to_last_uses.get(user, [])
    if len(nodes_to_delete):
        delete_size = sum([_get_output_node_size(i) for i in nodes_to_delete])
        return delete_size
    return 0

def _get_last_usr(nodes):
    node_to_last_use: Dict[Node, Node] = {}
    user_to_last_uses: Dict[Node, List[Node]] = {}

    def register_last_uses(n: Node, user: Node):
        if n not in node_to_last_use:
            node_to_last_use[n] = user
            user_to_last_uses.setdefault(user, []).append(n)

    for node in reversed(nodes):
        map_arg(node.args, lambda n: register_last_uses(n, node))
        map_arg(node.kwargs, lambda n: register_last_uses(n, node))
    return user_to_last_uses
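For intuition, the same reverse sweep can be run on a toy traced module: walking the node list backwards, the first user encountered for a value is its last user in forward order, so `user_to_last_uses[user]` lists the values whose memory can be released once `user` has executed. The `Tiny` module below is only an illustration:

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.node import map_arg

class Tiny(torch.nn.Module):
    def forward(self, x):
        a = x + 1
        b = a * 2
        return a + b

gm = symbolic_trace(Tiny())
nodes = list(gm.graph.nodes)

node_to_last_use, user_to_last_uses = {}, {}

def register_last_uses(n, user):
    # seen first in the reverse sweep == used last in execution order
    if n not in node_to_last_use:
        node_to_last_use[n] = user
        user_to_last_uses.setdefault(user, []).append(n)

for node in reversed(nodes):
    map_arg(node.args, lambda n: register_last_uses(n, node))
    map_arg(node.kwargs, lambda n: register_last_uses(n, node))

for user, released in user_to_last_uses.items():
    print(user.name, 'releases', [n.name for n in released])
# output releases ['add_1'], add_1 releases ['add', 'mul'], add releases ['x']
```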

def _estimate_inference_mem(gm: torch.fx.GraphModule):
    act_memory = 0
    act_memory_peak_log = []
    act_memory_after_node_log = []
    user_to_last_uses = _get_last_usr(list(gm.graph.nodes))
    for node in gm.graph.nodes:
        # if node is placeholder, just add the size of the node
        if node.op == 'placeholder':
            act_memory += _get_meta_node_size(node)
        # skip output
        elif node.op == 'output':
            continue
        # node is an operation, calculate tmp, output node and delete node memory
        else:
            # forward memory
            act_memory += calculate_fwd_tmp(node)
            # act_memory += calculate_fwd_out(node)
            act_memory += _get_output_node_size(node)
            # record max act memory
            act_memory_peak_log.append(act_memory)
            # delete useless memory
            act_memory -= calculate_fwd_tmp(node)
            act_memory -= _get_delete_node_size(node, user_to_last_uses)
        act_memory_after_node_log.append(act_memory)

    act_memory_peak_log = [float(i) / (1024 ** 2) for i in act_memory_peak_log]
    param_memory = parameter_size(gm)
    return (act_memory + param_memory) / (1024 ** 2), param_memory / (1024 ** 2)
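A rough usage sketch, mirroring the test further down in this commit: the estimate only makes sense after `MetaInfoProp` has populated `node.meta` (`tensor_meta`, `fwd_out`) on the traced module. `MyModel` and `sample` are placeholders, and the exact import paths for `MetaInfoProp`/`MetaTensor` should be checked against the repo:

```python
import torch
from torch.fx import symbolic_trace
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import MetaTensor

model = MyModel().cuda()                     # placeholder module under test
sample = torch.randn(1, 32, 32, 128).cuda()  # placeholder input

gm_prop = symbolic_trace(model)
MetaInfoProp(gm_prop).propagate(MetaTensor(sample, fake_device='cuda:0'))

act_plus_param_mb, param_mb = _estimate_inference_mem(gm_prop)
print(f"estimated forward memory: {act_plus_param_mb:.1f} MB (parameters: {param_mb:.1f} MB)")
```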

def _estimate_chunk_forward_mem(gm: torch.fx.GraphModule, start_node, end_node, chunk_size):
    node_size = 0
    param_size = 0
    for node in gm.graph.nodes:
        node_size += calculate_fwd_tmp(node)
        node_size += calculate_fwd_out(node)
    param_size = parameter_size(gm)
    return (node_size + param_size) / 1024**2, param_size / 1024**2
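As written, `_estimate_chunk_forward_mem` does not yet use `start_node`, `end_node` or `chunk_size`, so it returns the same totals for any chunk setting. One way the chunk arguments could enter the estimate, shown purely as an illustration (this is not the patch's implementation, and `full_size` is an invented parameter):

```python
def _estimate_chunk_forward_mem_sketch(gm, start_node, end_node, chunk_size, full_size):
    # illustrative only: activations inside the chunked region live one slice
    # at a time, so scale them by chunk_size / full_size
    node_size = 0
    inside_region = False
    for node in gm.graph.nodes:
        if node is start_node:
            inside_region = True
        size = calculate_fwd_tmp(node) + calculate_fwd_out(node)
        node_size += size * chunk_size / full_size if inside_region else size
        if node is end_node:
            inside_region = False
    param_size = parameter_size(gm)
    return (node_size + param_size) / 1024**2, param_size / 1024**2
```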

def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
    new_shape = "["
    for idx, i in enumerate(shape):
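The rest of `_gen_chunk_slice_dim` falls outside this hunk; judging from the `new_shape = "["` prefix it builds a slice-expression string for the chunked dimension. A guess at the shape of such a helper, illustrative only (including the `chunk_size` name):

```python
def gen_chunk_slice_dim_sketch(chunk_dim, chunk_idx_name, shape):
    # build something like "[:, chunk_idx:chunk_idx + chunk_size, :, :]"
    parts = []
    for idx, _ in enumerate(shape):
        if idx == chunk_dim:
            parts.append(f"{chunk_idx_name}:{chunk_idx_name} + chunk_size")
        else:
            parts.append(":")
    return "[" + ", ".join(parts) + "]"

print(gen_chunk_slice_dim_sketch(1, "chunk_idx", (1, 32, 32, 128)))
# [:, chunk_idx:chunk_idx + chunk_size, :, :]
```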

@@ -342,7 +419,7 @@ def emit_ckpt_func(body,
    body.append(usage)


-def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func, meta_nodes):
+def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func, meta_nodes, meta_graph):
    """Emit code with nested activation checkpoint
    When we detect some of the node.activation_checkpoint is a List, we will use
    this function to emit the activation checkpoint codes.
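For orientation, the kind of forward code a chunked region compiles down to can be written by hand: split the chunk dimension into fixed-size slices, run the region on each slice, and concatenate the partial results. The snippet below is a conceptual equivalent only (with `region_forward` standing in for the nodes between a region's start and end), not the literal text this function emits:

```python
import torch

def region_forward(t):
    # stand-in for the ops between the chunk region's start and end node
    return torch.softmax(t, dim=-1)

x = torch.randn(1, 32, 64)
chunk_size = 8

chunk_result = []
for chunk_idx in range(0, x.shape[1], chunk_size):
    chunk_input = x[:, chunk_idx:chunk_idx + chunk_size]   # slice along the chunk dim
    chunk_result.append(region_forward(chunk_input))
out = torch.cat(chunk_result, dim=1)

# chunking along dim 1 does not change the result, only the peak activation size
assert torch.allclose(out, region_forward(x))
```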

@@ -364,6 +441,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
    within_chunk_region = False

    node_list = list(nodes)
+    _estimate_inference_mem(meta_graph)

    # find the input and output var names for each offload region
    for idx, (start, end) in enumerate(chunk_regions):

@@ -418,6 +496,7 @@ if CODEGEN_AVAILABLE:

    class ChunkCodeGen(CodeGen):

        def __init__(self, meta_graph):
            super().__init__()
            self.meta_graph = meta_graph
+            self.meta_node = list(meta_graph.graph.nodes)

        def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:

@@ -612,7 +691,7 @@ if CODEGEN_AVAILABLE:

            # if any node has a list of labels for activation_checkpoint, we
            # will use nested type of activation checkpoint codegen
-            emit_code_with_chunk(body, ckpt_func, nodes, emit_node, delete_unused_values, self.meta_node)
+            emit_code_with_chunk(body, ckpt_func, nodes, emit_node, delete_unused_values, self.meta_node, self.meta_graph)

            if len(body) == 0:
                # If the Graph has no non-placeholder nodes, no lines for the body

@@ -2,6 +2,7 @@ import copy
import torch
import torch.nn.functional as F
import pytest
import torch.fx
import torch.multiprocessing as mp
from torch.fx import GraphModule
from colossalai.fx import ColoTracer

@@ -56,18 +57,15 @@ def _run_offload_codegen(rank):
    pair = torch.randn(1, 32, 32, 128).cuda()

    # trace the module and replace codegen
-    tracer = ColoTracer(trace_act_ckpt=True)
-    graph = tracer.trace(model)
-    gm_prop = torch.fx.GraphModule(model, graph)
+    graph = ColoTracer().trace(model, meta_args={'node': node.to(torch.device('meta')), 'pair': pair.to(torch.device('meta'))})
+    gm_prop = torch.fx.symbolic_trace(model)    # must use symbolic_trace
    interp = MetaInfoProp(gm_prop)
    interp.propagate(MetaTensor(node, fake_device='cuda:0'), MetaTensor(pair, fake_device='cuda:0'))

    # annotate the chunk part
    # for node in graph.nodes:
    #     if node.name == "linear0":
    #         setattr(node, "activation_offload", [0, True, False])
    #     if node.name == "linear1":
    #         setattr(node, "activation_offload", [0, True, False])

    # now run it twice to get meta info in graph module, not necessary
    gm = torch.fx.GraphModule(model, graph)
    interp = MetaInfoProp(gm)
    interp.propagate(MetaTensor(node, fake_device='cuda:0'), MetaTensor(pair, fake_device='cuda:0'))

    codegen = ChunkCodeGen(gm_prop)
    graph.set_codegen(codegen)
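Presumably the test then rebuilds and recompiles the module so that fx regenerates `forward()` through the new codegen; that continuation is outside this hunk, but it would look roughly like:

```python
gm = torch.fx.GraphModule(model, graph)
gm.recompile()          # forward() is now generated by ChunkCodeGen
print(gm.code)          # inspect the emitted chunk loop
out = gm(node, pair)    # run the chunked forward on the real inputs
```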