diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py index eeb43f3a7..b72d20fd2 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py @@ -114,12 +114,57 @@ def _discretize(mem_unit, values): return [math.ceil(value / mem_unit) for value in values] -def _construct_chain(node_list: List[List[Node]], data: torch.Tensor, mem_unit: int) -> Chain: +def _compute_size(obj: torch.Tensor) -> int: + return obj.numel() * obj.element_size() + + +def _compute_output_size(node: List[Node]) -> int: + """Compute the output size of a node + + Args: + node (List[Node]): node, list of torch.fx.Node + + Returns: + int: output size + """ + + return node[-1].meta['tensor_meta'].numel * torch.tensor([], + dtype=node[-1].meta['tensor_meta'].dtype).element_size() + + +def _get_inplace(node: Node) -> bool: + """Get the inplace argument from torch.fx.Node + + Args: + node (Node): torch.fx.Node + + Returns: + bool: indicates whether this op is inplace + """ + + is_inplace = False + if node.op == "call_function": + is_inplace = node.kwargs.get("inplace", False) + elif node.op == "call_module": + is_inplace = getattr(node.graph.owning_module.get_submodule(node.target), "inplace", False) + + return is_inplace + + +def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain: fwd_time = [] bwd_time = [] - xbar_sizes = [data.numel() * data.element_size()] - x_sizes = [data.numel() * data.element_size()] + + if isinstance(data, torch.Tensor): + xbar_sizes = [_compute_size(data)] + x_sizes = [_compute_size(data)] + elif isinstance(data, list) or isinstance(data, tuple): + xbar_sizes = [_compute_size(obj) for obj in data] + x_sizes = [_compute_size(obj) for obj in data] + elif isinstance(data, dict): + xbar_sizes = [_compute_size(obj) for obj in data.values()] + x_sizes = [_compute_size(obj) for obj in data.values()] # currently we can't get the 
temp memory needed in fwd and bwd tmp_fwd = [0] * len(node_list) @@ -129,16 +174,27 @@ def _construct_chain(node_list: List[List[Node]], data, mem_unit: fwd_time.append(0) bwd_time.append(0) xbar_sizes.append(0) - x_sizes.append(node[-1].meta['tensor_meta'].numel * - torch.tensor([], dtype=node[-1].meta['tensor_meta'].dtype).element_size()) + x_sizes.append(_compute_output_size(node)) + + _check_inplace_flag = 1 for n in node: fwd_time[-1] += max(n.__flops__, 1) # currently we haven't patched the backward flops count bwd_time[-1] += max(n.__flops__ * 2, 2) - xbar_sizes[-1] += n.__activation__ + # we need to clear the xbar of previous node as there is + # one op in the current node that use the previous node's + # output but applies inplace operation on it + # NOTE: This process should be done only once as the previous + # node will only have one output + if _check_inplace_flag: + for par in n._input_nodes: + if par not in node and _get_inplace(n): + xbar_sizes[-2] -= x_sizes[-2] + _check_inplace_flag = 0 + xbar_sizes[-1] = max(xbar_sizes[-1], x_sizes[-1]) bwd_time.append(0) @@ -186,20 +242,25 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]) -> ckpt_region.append(idx) -def solver_rotor(gm: ColoGraphModule, data: torch.Tensor, mem_limit: int, mem_slots: int = 500) -> ColoGraphModule: +def solver_rotor(gm: ColoGraphModule, + data, + mem_limit: int, + mem_slots: int = 500, + cnode: List[str] = None) -> ColoGraphModule: """solver that automatically find activation checkpoint in rotor's manner Args: gm (ColoGraphModule): ColoGraphModule generated by tracing model. data (torch.Tensor): input data. mem_limit (int): memory budget in Byte. - mem_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500. + mem_slots (int, optional): number of slots for discretizing memory budget. Defaults to 500. + cnode (List[str], optional): common node list for linearize. Defaults to None. 
Returns: ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute """ - node_list = linearize(gm) + node_list = linearize(gm, cnode) mem_unit = mem_limit // mem_slots MetaInfoProp(gm).run(data) chain: Chain = _construct_chain(node_list, data, mem_unit) diff --git a/colossalai/fx/passes/algorithms/linearize.py b/colossalai/fx/passes/algorithms/linearize.py index e6b47a7ba..f8c531356 100644 --- a/colossalai/fx/passes/algorithms/linearize.py +++ b/colossalai/fx/passes/algorithms/linearize.py @@ -2,11 +2,12 @@ from typing import List from torch.fx import GraphModule, Node -def linearize(gm: GraphModule) -> List[List[Node]]: +def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]: """Linearizing the graph Args: gm (GraphModule): GraphModule derived by tracing + cnode (List[str], optional): common node list, should be a subset of the input. Defaults to None. Returns: List[List[Node]]: List of list, each inside list of Node presents @@ -22,23 +23,39 @@ def linearize(gm: GraphModule) -> List[List[Node]]: return not sum([v for _, v in deps.items()]) + # make sure that item in cnode is valid + if cnode: + for name in cnode: + try: + assert next(node for node in gm.graph.nodes if node.name == name).op == "placeholder", \ + f"common node {name} is not an input of the model" + except StopIteration: + raise ValueError(f"common node name {name} not in graph") + + else: + cnode = [] + deps = {} linearized_nodes = [] region = [] for n in gm.graph.nodes: - for n_par in n._input_nodes: - deps[n_par] -= 1 - region.append(n) + if n.op != "placeholder" and n.op != "output": + for n_par in n._input_nodes: + if n_par.op != "placeholder" and n_par.name not in cnode: + deps[n_par] -= 1 + region.append(n) - # if the node could free all dependencies in graph - # we could begin a new node - if _is_sink(): - linearized_nodes.append(region) - region = [] + # if the node could free all dependencies in graph + # we could begin a new node + if _is_sink(): + 
linearized_nodes.append(region) + region = [] - deps[n] = len(n.users) + # propagate common node attr if possible + if len(n._input_nodes) == len([node for node in n._input_nodes if node.name in cnode]): + cnode.append(n.name) + else: + deps[n] = len([user for user in n.users if user.op != "output"]) - # Remove input - linearized_nodes = linearized_nodes[1:-1] return linearized_nodes