From cc55ff0aa41d7dfddf040598cea3c41bcc35ac5a Mon Sep 17 00:00:00 2001
From: Super Daniel <78588128+super-dainiu@users.noreply.github.com>
Date: Thu, 10 Nov 2022 20:59:28 +0800
Subject: [PATCH] [autoparallel] user-friendly API for CheckpointSolver. (#1879)

Merge for SC tutorial
---
 .../checkpoint/ckpt_solver_base.py  | 16 ++++++++++------
 .../checkpoint/ckpt_solver_rotor.py | 17 ++++++-----------
 2 files changed, 16 insertions(+), 17 deletions(-)

diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py
index 591f5fd25..63eff31b2 100644
--- a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py
+++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py
@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
 from copy import deepcopy
 from typing import Any, List
 
+import torch
 from torch.fx import Graph, Node
 
 from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
@@ -17,13 +18,17 @@ def _copy_output(src: Graph, dst: Graph):
             n_dst.meta = n_src.meta
 
 
+def _get_param_size(module: torch.nn.Module):
+    """Get the size of the parameters in the module"""
+    return sum([p.numel() * torch.tensor([], dtype=p.dtype).element_size() for p in module.parameters()])
+
+
 class CheckpointSolverBase(ABC):
 
     def __init__(
         self,
         graph: Graph,
-        memory_budget: float = -1.0,
-        parameter_size: float = 0,
+        free_memory: float = -1.0,
         requires_linearize: bool = False,
         cnode: List[str] = None,
     ):
@@ -37,8 +42,7 @@ class CheckpointSolverBase(ABC):
 
         Args:
             graph (Graph): The computing graph to be optimized.
-            memory_budget (float): Memory constraint for the solution.
-            parameter_size (float): The size of parameter of this model. Use `parameter_size(model)` to estimate.
+            free_memory (float): Memory constraint for the solution.
             requires_linearize (bool): Whether the graph needs to be linearized.
             cnode (List[str], optional): Common node List, should be the subset of input. Default to None.
 
@@ -58,8 +62,8 @@
             raise RuntimeError(
                 "Nodes meta information hasn't been prepared! Please run MetaInfoProp before constructing the solver!")
 
-        self.memory_budget = memory_budget
-        self.parameter_size = parameter_size
+        self.free_memory = free_memory
+        self.parameter_size = _get_param_size(self.graph.owning_module)
         self.cnode = cnode
         self.requires_linearize = requires_linearize
         if self.requires_linearize:
diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
index 22dbc8be0..72bc67e02 100644
--- a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
+++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
@@ -22,12 +22,7 @@ __all__ = ['CheckpointSolverRotor']
 
 class CheckpointSolverRotor(CheckpointSolverBase):
 
-    def __init__(self,
-                 graph: Graph,
-                 memory_budget: float = -1,
-                 parameter_size: float = 0,
-                 cnode: List[str] = None,
-                 memory_slots: int = 500):
+    def __init__(self, graph: Graph, free_memory: float = -1, cnode: List[str] = None, memory_slots: int = 500):
         """This is the simple implementation of dynamic programming algorithm rotor
         in https://hal.inria.fr/hal-02352969. Some code are adapted from
         https://gitlab.inria.fr/hiepacs/rotor.
@@ -36,22 +31,22 @@ class CheckpointSolverRotor(CheckpointSolverBase):
         Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp`
         to the graph to retrieve all information needed, then we could use the following
         code to find a solution using `CheckpointSolverRotor`:
-            >>> solver = CheckpointSolverRotor(gm.graph, memory_budget=memory_budget, parameter_size=parameter_size)
+            >>> solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info(device=0)[0])
             >>> rotor_graph = solver.solve(force_python=True)   # otherwise use C solver
             >>> gm.graph = rotor_graph   # set the graph to a new graph
 
         Args:
             graph (Graph): The computing graph to be optimized.
-            memory_budget (float, optional): Memory constraint for the solution, unit is byte.
-            parameter_size (float, optional): The size of parameter of this model, unit is byte. Use `parameter_size(model)` to estimate.
+            free_memory (float, optional): Memory constraint for the solution, unit is byte.
+                Use ``torch.cuda.mem_get_info(device=0)[0]`` to estimate the free_memory. Defaults to -1.
             cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None.
             memory_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500.
         """
-        super().__init__(graph, memory_budget, parameter_size, True, cnode)
+        super().__init__(graph, free_memory, True, cnode)
         self.memory_slots = memory_slots
 
         # construct chain
-        unit = self.memory_budget // self.memory_slots
+        unit = self.free_memory // self.memory_slots
        self.chain = self._construct_chain(self.graph, self.node_list)
         self.chain.discretize_all(unit)
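
A minimal end-to-end sketch of the new API, mirroring the updated docstrings above. It assumes `gm` is a traced `torch.fx.GraphModule` that `MetaInfoProp` has already annotated, and the `colossalai.auto_parallel.checkpoint` import path is an assumption about the package layout:

    import torch

    # assumed import path for the solver exported via __all__ above
    from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor

    # torch.cuda.mem_get_info returns (free, total) in bytes; index 0 is the
    # free device memory that serves as the solver's budget
    free_memory = torch.cuda.mem_get_info(device=0)[0]

    # parameter_size is no longer passed in: CheckpointSolverBase now derives
    # it from graph.owning_module via _get_param_size
    solver = CheckpointSolverRotor(gm.graph, free_memory=free_memory)
    gm.graph = solver.solve(force_python=True)    # otherwise use the C solver
    gm.recompile()    # standard torch.fx step to regenerate forward() from the rewritten graph

With this change, callers supply only the free-memory budget; `_get_param_size` sums `p.numel() * element_size` over `module.parameters()` to recover what the old `parameter_size(model)` argument provided.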