diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py
index 9ebbd48c7..860804f48 100644
--- a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py
+++ b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py
@@ -73,10 +73,11 @@ def chen_greedy(gm: GraphModule) -> GraphModule:
     y = 0
     prev_idx = 2
     for (idx, n) in enumerate(gm.graph.nodes):
-        temp += getattr(n, 'fwd_out')
+        n: Node
+        temp += n.meta['fwd_mem_out'] + n.meta['fwd_mem_tmp']
         y = max(y, temp)
         if temp > b and n in ckpt_nodes:
-            x += getattr(n, 'fwd_out')
+            x += n.meta['fwd_mem_out']
             temp = 0
             ckpt_intv.append((prev_idx, idx + 1))
             prev_idx = idx + 1
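Note on the `chen_greedy` hunk above: per-node memory now comes from the profiler-populated `n.meta` dict (`fwd_mem_out` plus `fwd_mem_tmp`) instead of a bare `fwd_out` attribute, so temporary forward memory also counts toward the budget. For readers unfamiliar with the algorithm, here is a minimal, self-contained sketch of the greedy segmentation idea from Chen et al.'s sqrt(n) checkpointing that this loop implements. It is simplified: `mem_costs` and `budget` are hypothetical stand-ins for the per-node meta values and the budget `b`, and the `ckpt_nodes` membership test and `x`/`y` bookkeeping are dropped.

```python
from typing import List, Tuple


def greedy_segments(mem_costs: List[int], budget: int) -> List[Tuple[int, int]]:
    """Split a linearized graph into checkpoint intervals.

    A new interval closes whenever the running activation memory `temp`
    exceeds `budget`, mirroring the loop in `chen_greedy` above.
    """
    intervals = []
    temp, prev_idx = 0, 0
    for idx, cost in enumerate(mem_costs):
        temp += cost    # was `getattr(n, 'fwd_out')`, now read from `n.meta`
        if temp > budget:
            temp = 0
            intervals.append((prev_idx, idx + 1))
            prev_idx = idx + 1
    if prev_idx < len(mem_costs):
        intervals.append((prev_idx, len(mem_costs)))
    return intervals


print(greedy_segments([4, 2, 6, 1, 3, 5], budget=7))    # [(0, 3), (3, 6)]
```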
diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
index 0d8ed9553..9cb48828e 100644
--- a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
+++ b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
@@ -1,11 +1,11 @@
-from typing import List, Set, Tuple, Dict
+from typing import List, Tuple
 import torch
 from torch.fx import GraphModule, Node
 from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.profiler import parameter_size
 import math
 from .linearize import linearize
 from .utils import *
-from colossalai.fx.profiler import profile_function, profile_module
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
 
@@ -25,8 +25,8 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
     bw = chain.bweight    ## backward time, not used
     cw = chain.cweight + [0]    ## size of x (and of y)
     cbw = chain.cbweight + [0]    ## size of xbar
-    fwd_tmp = chain.fwd_tmp + [0]
-    bwd_tmp = chain.bwd_tmp + [0]
+    fwd_mem_tmp = chain.fwd_mem_tmp + [0]
+    bwd_mem_tmp = chain.bwd_mem_tmp + [0]
 
     # Build table
     opt = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
@@ -37,7 +37,7 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
     for m in range(mmax + 1):
        for i in range(chain.length + 1):
            #lmax-lmin = 0
-           limit = max(cw[i + 1] + cbw[i + 1] + fwd_tmp[i], cw[i + 1] + cbw[i + 1] + bwd_tmp[i])
+           limit = max(cw[i + 1] + cbw[i + 1] + fwd_mem_tmp[i], cw[i + 1] + cbw[i + 1] + bwd_mem_tmp[i])
            if m >= limit:    ## Equation (1)
                opt[m][i][i] = fw[i] + bw[i]
            else:
@@ -49,9 +49,9 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
        for i in range(chain.length + 1 - d):
            # for idx in range(i+1, chain.length + 1):
            idx = i + d
-           mmin = cw[idx + 1] + cw[i + 1] + fwd_tmp[i]
+           mmin = cw[idx + 1] + cw[i + 1] + fwd_mem_tmp[i]
            if idx > i + 1:
-               mmin = max(mmin, cw[idx + 1] + max(cw[j] + cw[j + 1] + fwd_tmp[j] for j in range(i + 1, idx)))
+               mmin = max(mmin, cw[idx + 1] + max(cw[j] + cw[j + 1] + fwd_mem_tmp[j] for j in range(i + 1, idx)))
            if m < mmin:
                opt[m][i][idx] = float("inf")
            else:
@@ -165,7 +165,7 @@ def _fwd_xbar(node: List[Node]) -> int:
 
     xbar = 0
     for n in node:
-        xbar += n.fwd_tmp + n.fwd_out
+        xbar += n.meta['fwd_mem_tmp'] + n.meta['fwd_mem_out']
     return xbar
 
 
@@ -183,7 +183,7 @@ def _fwd_time(node: List[Node]) -> int:
     fwd_time = 0
     for n in node:
         # minimum flop count is needed
-        fwd_time += max(n.fwd_flop, 1)
+        fwd_time += max(n.meta['fwd_flop'], 1)
     return fwd_time
 
 
@@ -201,11 +201,11 @@ def _bwd_time(node: List[Node]) -> int:
     bwd_time = 0
     for n in node:
         # minimum flop count is needed
-        bwd_time += max(n.bwd_flop, 1)
+        bwd_time += max(n.meta['bwd_flop'], 1)
     return bwd_time
 
 
-def _get_bwd_tmp(node: List[Node]) -> int:
+def _get_bwd_mem_tmp(node: List[Node]) -> int:
     """Get the backward temp memory of a node
 
     Args:
@@ -218,29 +218,32 @@ def _get_bwd_mem_tmp(node: List[Node]) -> int:
 
     def _get_deps_size():
         deps_size = 0
-        for key in deps.keys():
-            deps_size += key.bwd_out
+        for k, v in deps.items():
+            if v > 0:
+                deps_size += k.meta['bwd_mem_out']
 
         return deps_size
 
-    bwd_tmp = 0
+    bwd_mem_tmp = 0
     deps = {}
 
     # add all the users for last node into deps,
     # as those nodes' gradient out will be stored in memory
-    for son in node[-1].users:
-        deps[son] = 1
+    for child in node[-1].users:
+        deps[child] = 1
     for n in reversed(node):
-        bwd_tmp = max(bwd_tmp, _get_deps_size() + n.bwd_tmp)
-        deps[n] = len(n._input_nodes)
-        for son in n.users:
-            deps[son] -= 1
+        bwd_mem_tmp = max(bwd_mem_tmp, _get_deps_size() + n.meta['bwd_mem_tmp'])
+
+        deps[n] = len(n.all_input_nodes)
+        for child in n.users:
+            if child in deps:
+                deps[child] -= 1
         for key in list(deps.keys()):
             if deps[key] == 0:
                 del deps[key]
 
-    return bwd_tmp
+    return bwd_mem_tmp
 
 
 def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
@@ -267,7 +270,7 @@ def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
         bwd_time.append(_bwd_time(node))
         x_sizes.append(_compute_output_size(node))
         xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node)))
-        tmp_bwd.append(_get_bwd_tmp(node))
+        tmp_bwd.append(_get_bwd_mem_tmp(node))
 
         # if a node with only one inplace op, we need to let x_bar = 0
         if len(node) == 1 and _get_inplace(node[0]):
@@ -394,6 +397,7 @@ def solver_rotor(gm: ColoGraphModule,
     """
 
     node_list = linearize(gm, cnode)
+    mem_limit -= parameter_size(gm)
     mem_unit = mem_limit * (1.0 - eps) // mem_slots
     MetaInfoProp(gm).run(data)
     chain: Chain = _construct_chain(node_list, data, mem_unit)
diff --git a/colossalai/fx/passes/algorithms/utils.py b/colossalai/fx/passes/algorithms/utils.py
index d26f1a2e2..78fb0c363 100644
--- a/colossalai/fx/passes/algorithms/utils.py
+++ b/colossalai/fx/passes/algorithms/utils.py
@@ -5,24 +5,24 @@ class Chain:
         self.bweight = bw
         self.cweight = cw
         self.cbweight = cbw
-        self.fwd_tmp = ftmp
-        self.bwd_tmp = btmp
+        self.fwd_mem_tmp = ftmp
+        self.bwd_mem_tmp = btmp
         self.length = len(fw)
         if check and not self.check_lengths():
             raise AttributeError("In Chain, input lists do not have consistent lengths")
 
     def check_lengths(self):
         return ((len(self.fweight) == self.length) and (len(self.bweight) == self.length + 1)
-                and (len(self.cweight) == self.length + 1) and (len(self.fwd_tmp) == self.length)
-                and (len(self.bwd_tmp) == self.length + 1) and (len(self.cbweight) == self.length + 1))
+                and (len(self.cweight) == self.length + 1) and (len(self.fwd_mem_tmp) == self.length)
+                and (len(self.bwd_mem_tmp) == self.length + 1) and (len(self.cbweight) == self.length + 1))
 
     def __repr__(self):
         chain_list = []
         for i in range(self.length):
-            chain_list.append(
-                (self.fweight[i], self.bweight[i], self.cweight[i], self.cbweight[i], self.fwd_tmp[i], self.bwd_tmp[i]))
+            chain_list.append((self.fweight[i], self.bweight[i], self.cweight[i], self.cbweight[i], self.fwd_mem_tmp[i],
+                               self.bwd_mem_tmp[i]))
         i = self.length
-        chain_list.append((None, self.bweight[i], self.cweight[i], self.cbweight[i], None, self.bwd_tmp[i]))
+        chain_list.append((None, self.bweight[i], self.cweight[i], self.cbweight[i], None, self.bwd_mem_tmp[i]))
         return chain_list.__repr__()
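Note on the `Chain` rename above: the rotor solver's per-stage inputs obey strict length invariants (forward weights have `length` entries; backward weights, `cw`, `cbw`, and `bwd_mem_tmp` have `length + 1`), which `check_lengths` enforces. A hedged usage sketch follows; the constructor parameter names `(fw, bw, cw, cbw, ftmp, btmp)` are inferred from the assignments in the hunk, and the numbers are made up.

```python
from colossalai.fx.passes.algorithms.utils import Chain

chain = Chain(
    fw=[1, 1, 1],            # forward time per stage: `length` entries
    bw=[2, 2, 2, 2],         # backward time: `length + 1` entries
    cw=[8, 8, 8, 8],         # size of x: `length + 1` entries
    cbw=[16, 16, 16, 16],    # size of xbar: `length + 1` entries
    ftmp=[4, 4, 4],          # becomes `fwd_mem_tmp`: `length` entries
    btmp=[4, 4, 4, 4],       # becomes `bwd_mem_tmp`: `length + 1` entries
)
assert chain.length == 3
assert len(chain.fwd_mem_tmp) == chain.length
assert len(chain.bwd_mem_tmp) == chain.length + 1
print(chain)    # __repr__ emits one (fw, bw, cw, cbw, ftmp, btmp) tuple per stage
```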
diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py
index 1d2638a02..84efca13a 100644
--- a/colossalai/fx/passes/meta_info_prop.py
+++ b/colossalai/fx/passes/meta_info_prop.py
@@ -94,12 +94,11 @@ class MetaInfoProp(torch.fx.Interpreter):
 
         tensor_meta = tree_map(extract_tensor_meta, result)
         n.meta['tensor_meta'] = tensor_meta
-        n.meta = {**n.meta, **asdict(meta_info)}    # extend MetaInfo to `n.meta`
-
+        n.meta = {**n.meta, **asdict(meta_info), 'fwd_mem_out': 0}    # extend MetaInfo to `n.meta`
         # TODO: the attribute node_size should be removed in the future
         setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
         for par in n.all_input_nodes:
-            par.meta['fwd_mem_out'] = par.meta.get('fwd_mem_out', 0) + n.meta.get('fwd_mem_in', 0)
+            par.meta['fwd_mem_out'] = max(par.meta.get('fwd_mem_out', 0), n.meta.get('fwd_mem_in', 0))
         n.meta['type'] = type(result)
 
         # retain the autograd graph
diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py
index f6efbf312..69319b792 100644
--- a/colossalai/fx/profiler/dataflow.py
+++ b/colossalai/fx/profiler/dataflow.py
@@ -1,11 +1,12 @@
 from dataclasses import dataclass
 from enum import Enum
+from functools import partial
 from typing import Dict
 from torch.fx import Graph, Node
 from .memory import activation_size
 
 
-class Stage(Enum):
+class Phase(Enum):
     FORWARD = 0
     LOSS = 1
     BACKWARD = 2
@@ -48,24 +49,9 @@ class GraphInfo:
     bwd_mem_out: int = 0
 
 
-def is_forward(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.FORWARD
-
-
-def is_loss(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.LOSS
-
-
-def is_placeholder(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.PLACEHOLDER
-
-
-def is_backward(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.BACKWARD
+def is_phase(n: Node, phase: Phase) -> bool:
+    assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!'
+    return n.meta['phase'] == phase
 
 
 def is_saved(n: Node):
@@ -74,7 +60,7 @@ def is_saved(n: Node):
 
 def autograd_graph_analysis(graph: Graph) -> GraphInfo:
     """Analyze the autograd node dependencies and find out the memory usage.
-    Basically the input graph should have all nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`.
+    Basically the input graph should have all nodes marked for keyword `phase`.
     Nodes should have attribute `out` indicating the output of each node.
     ============================================================================
     Placeholder ---->  p           o     <---- We need to keep track of grad out
@@ -91,18 +77,18 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
                          l
     =============================================================================
     Args:
-        graph (Graph): The autograd graph with nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`.
+        graph (Graph): The autograd graph with nodes marked for keyword `phase`.
 
     Returns:
         graph_info (GraphInfo): Meta information for the dataflow.
     """
 
     def _peak_memory(deps: Dict[Node, int]):
-        bwd_tmp = 0
+        peak_mem = 0
         for k, v in deps.items():
             if v > 0:
-                bwd_tmp += activation_size(k.meta['out'])
-        return bwd_tmp
+                peak_mem += activation_size(k.meta['out'])
+        return peak_mem
 
     # deps is used to track all the memory dependencies of the graph.
     deps = {}
@@ -110,19 +96,19 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
 
     for n in graph.nodes:
         n: Node
-        if is_saved(n) and not any(map(is_loss, n.users)):
+        if is_saved(n) and not any(map(partial(is_phase, phase=Phase.LOSS), n.users)):
             # A forward tensor who is marked `save` but is not
             # an input to `loss` should be saved during forward.
-            # If the tensor is a placeholder, then it belongs to `fwd_in`.
-            # Any `fwd_in` should be kept in memory even this function
+            # If the tensor is a placeholder, then it belongs to `fwd_mem_in`.
+            # Any `fwd_mem_in` should be kept in memory even this function
             # is checkpointed.
-            # Otherwise, the tensor belongs to `fwd_tmp`. If we checkpoint
-            # the node, `fwd_tmp` can be freed.
-            if is_placeholder(n):
+            # Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
+            # the node, `fwd_mem_tmp` can be freed.
+            if is_phase(n, Phase.PLACEHOLDER):
                 graph_info.fwd_mem_in += activation_size(n.meta['out'])
-            if is_forward(n):
+            if is_phase(n, Phase.FORWARD):
                 graph_info.fwd_mem_tmp += activation_size(n.meta['out'])
-        elif is_backward(n):
+        elif is_phase(n, Phase.BACKWARD):
             if len(n.users):
                 # liveness analysis is only used in backward
                 deps[n] = len(n.users)
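Note on the dataflow changes above: the four predicates `is_forward`/`is_loss`/`is_placeholder`/`is_backward` collapse into one `is_phase(n, phase)`, and call sites that used to pass a bare predicate now bind the phase with `functools.partial`, which is why that import is added. A self-contained sketch with a hypothetical node stub standing in for a `torch.fx` Node:

```python
from enum import Enum
from functools import partial


class Phase(Enum):    # mirrors colossalai.fx.profiler.dataflow.Phase
    FORWARD = 0
    LOSS = 1
    BACKWARD = 2
    PLACEHOLDER = 3


class FakeNode:    # hypothetical stand-in; real nodes get `phase` set by the profiler
    def __init__(self, phase: Phase):
        self.meta = {'phase': phase}


def is_phase(n, phase: Phase) -> bool:
    assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!'
    return n.meta['phase'] == phase


users = [FakeNode(Phase.FORWARD), FakeNode(Phase.LOSS)]
# Equivalent to the old `any(map(is_loss, users))`:
assert any(map(partial(is_phase, phase=Phase.LOSS), users))
```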
diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py
index 347c68c3a..a152385e8 100644
--- a/colossalai/fx/profiler/profiler.py
+++ b/colossalai/fx/profiler/profiler.py
@@ -5,8 +5,8 @@ import torch
 from torch.fx import Graph, Node
 from torch.fx.node import Argument, Target
 from torch.utils._pytree import tree_map
-from .dataflow import autograd_graph_analysis, Stage
-from .memory import WEIRD_OPS
+from .dataflow import GraphInfo, autograd_graph_analysis, Phase
+from .memory import WEIRD_OPS, activation_size
 from .tensor import MetaTensor
 from .opcount import flop_mapping
 
@@ -41,14 +41,11 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
 
     # `flop_count`` serves as a global dictionary to store results.
     flop_count = {
-        Stage.FORWARD: 0,
-        Stage.LOSS: 0,
-        Stage.BACKWARD: 0,
+        Phase.FORWARD: 0,
+        Phase.LOSS: 0,
+        Phase.BACKWARD: 0,
     }
 
-    # `stage` will mark the stage of autograd from outside scope.
-    stage = Stage.FORWARD
-
     # FlopTensor not only get the flop statistics of a single node,
     # it also build a full autograd graph for this node.
     # This makes sure we can analyze the dependencies of memory, and
@@ -85,9 +82,9 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
             # run aten for backend=CPU but actually on backend=Meta
             out = func(*args, **kwargs)
-            flop_count[stage] += flop_mapping[func](args, normalize_tuple(out))
+            flop_count[phase] += flop_mapping[func](args, normalize_tuple(out))
             node.meta['out'] = normalize_tuple(out)
-            node.meta['stage'] = stage
+            node.meta['phase'] = phase
 
             def wrap(x):
                 return FlopTensor(x.to('meta')) if isinstance(x, torch.Tensor) else x
 
@@ -121,7 +118,7 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
             x._node = subgraph.create_node('placeholder',
                                            'placeholder', (subgraph._root,),
                                            name=subgraph._graph_namespace.create_name('input', x._tensor))
-            x._node.meta['stage'] = Stage.PLACEHOLDER
+            x._node.meta['phase'] = Phase.PLACEHOLDER
             x._node.meta['out'] = (x._tensor,)
 
     tree_map(set_placeholder, args)
@@ -135,6 +132,8 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
     def unpack(x):
         return x
 
+    # `phase` will mark the phase of autograd from outside scope.
+    phase = Phase.FORWARD
     # mark saved tensors with saved_tensors_hooks
     with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
         if isinstance(target, str):
@@ -147,13 +146,13 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
     # If the output is not a floating point `torch.Tensor` or it does not
     # requires grad, then we should not run backward for this node.
     if is_autogradable(out) and out.requires_grad:
-        stage = Stage.LOSS
+        phase = Phase.LOSS
         loss = out.sum()
-        stage = Stage.BACKWARD
+        phase = Phase.BACKWARD
         loss.backward()
 
     graph_info = autograd_graph_analysis(subgraph)
-    graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Stage.FORWARD], flop_count[Stage.BACKWARD]
+    graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD]
 
     def unwrap(x):
         return x._tensor.to('meta') if isinstance(x, FlopTensor) else x
@@ -180,6 +179,11 @@ def profile_function(target: 'Target') -> Callable:
 
         # If there is an argument that this `call_function` is inplace, we should
         # skip the autograd profiling.
+        if kwargs.get('inplace', False):
+            args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args)
+            kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs)
+            out = func(*args, **kwargs)
+            return out, GraphInfo(out.numel(), out.numel(), activation_size((args, kwargs)), 0, activation_size(out), 0)
         out, meta = _profile(func, *args, **kwargs)
         return out, meta
 
@@ -222,6 +226,11 @@ def profile_module(module: torch.nn.Module) -> Callable:
 
         # If there is an argument that this `call_module` is inplace, we should
         # skip the autograd profiling.
+        if getattr(module, 'inplace', False):
+            args = tree_map(lambda x: x.to('meta'), args)
+            kwargs = tree_map(lambda x: x.to('meta'), kwargs)
+            out = func(*args, **kwargs)
+            return out, GraphInfo(out.numel(), out.numel(), activation_size((args, kwargs)), 0, activation_size(out), 0)
         out, meta = _profile(func, *args, inplace=getattr(module, 'inplace', False), **kwargs)
         return out, meta
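Note on the inplace short-circuits added to `profile_function` and `profile_module` above: autograd profiling is skipped, the node is charged `out.numel()` FLOPs for both passes, its inputs count as `fwd_mem_in`, its (aliased) output as `fwd_mem_out`, and no temporary memory is attributed. A hedged, self-contained illustration; the `GraphInfo` field order is assumed to mirror `colossalai.fx.profiler.dataflow.GraphInfo`, and `activation_size` is stubbed here as total tensor bytes:

```python
from dataclasses import dataclass

import torch


@dataclass
class GraphInfo:    # assumed field order, matching the positional call in the diff
    fwd_flop: int = 0
    bwd_flop: int = 0
    fwd_mem_in: int = 0
    fwd_mem_tmp: int = 0
    fwd_mem_out: int = 0
    bwd_mem_out: int = 0


def activation_size(x) -> int:    # stub: bytes held by the tensors in `x`
    if isinstance(x, torch.Tensor):
        return x.numel() * x.element_size()
    if isinstance(x, (tuple, list, dict)):
        vals = x.values() if isinstance(x, dict) else x
        return sum(activation_size(t) for t in vals)
    return 0


x = torch.rand(4, 4)
out = torch.relu_(x)    # an inplace op, so autograd profiling would be skipped
meta = GraphInfo(out.numel(), out.numel(), activation_size((x,)), 0, activation_size(out), 0)
print(meta)    # fwd_mem_tmp == 0 and bwd_mem_out == 0 for inplace ops
```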
diff --git a/tests/test_fx/test_ckpt_solvers/test_linearize.py b/tests/test_fx/test_ckpt_solvers/test_linearize.py
index 1f4d4a0bc..4b6f91a4d 100644
--- a/tests/test_fx/test_ckpt_solvers/test_linearize.py
+++ b/tests/test_fx/test_ckpt_solvers/test_linearize.py
@@ -38,7 +38,8 @@ def test_linearize():
         if isinstance(op, ForwardNograd):
             for n in node_list[idx]:
                 assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
-                assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
+                assert n.activation_checkpoint[
+                    0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!"
 
             continue
 
@@ -54,7 +55,8 @@ def test_linearize():
             ckpt_idx += 1
             for n in node_list[idx]:
                 assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
-                assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
+                assert n.activation_checkpoint[
+                    0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!"
 
             continue
 
@@ -63,7 +65,8 @@ def test_linearize():
             in_ckpt = True
             for n in node_list[idx]:
                 assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
-                assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
+                assert n.activation_checkpoint[
+                    0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!"
 
     del model
     del gm
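Note on the test changes above: the assertions now index `activation_checkpoint[0]`, which suggests the annotation has become a sequence of region indices (outermost first) rather than a single int, in line with the nested checkpoint regions handled by `_find_nested_ckpt_regions`. A hypothetical illustration:

```python
# Hypothetical annotation for a node inside nested checkpoint regions:
activation_checkpoint = [2, 0]    # outer region 2, inner region 0
assert activation_checkpoint[0] == 2    # the tests compare only the outer index
```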