@@ -1,6 +1,7 @@
 from typing import List, Set, Tuple, Dict
 import torch
 from torch.fx import GraphModule, Node
+from colossalai.fx.graph_module import ColoGraphModule
 import math
 from .linearize import linearize
 from .utils import *
@@ -131,10 +132,10 @@ def _construct_chain(node_dict: Dict[int, Node], data: torch.Tensor, mem_unit: int
         x_sizes.append(node_dict[key][-1].meta['tensor_meta'].numel *
                        torch.tensor([], dtype=node_dict[key][-1].meta['tensor_meta'].dtype).element_size())
         for node in node_dict[key]:
-            fwd_time[-1] += node.__flops__
+            fwd_time[-1] += max(node.__flops__, 1)
             # currently we haven't patched the backward flops count
-            bwd_time[-1] += node.__flops__ * 2
+            bwd_time[-1] += max(node.__flops__ * 2, 2)
             xbar_sizes[-1] += node.__activation__
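
Note on the clamping above (my reading of the in-line comment, not stated elsewhere in the patch): nodes whose FLOPs counter has not been patched report `__flops__ == 0`, so without a floor they would contribute zero forward/backward time to the chain and the solver could treat checkpointing them as free. The `max(..., 1)` / `max(..., 2)` floors keep every node's cost strictly positive while preserving the backward cost as twice the forward cost:

    # Illustrative values only; `__flops__` is filled in by the profiler.
    unpatched_flops = 0
    fwd_cost = max(unpatched_flops, 1)        # 1 instead of 0
    bwd_cost = max(unpatched_flops * 2, 2)    # 2 instead of 0, still bwd == 2 * fwd
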
@@ -164,16 +165,16 @@ def _annotate_from_sequence(sequence: Sequence, node_dict: Dict[int, Node]) -> GraphModule
         elif isinstance(op, ForwardEnable):
             in_ckpt = False
-            for idx in ckpt_region:
-                for node in node_dict[idx]:
+            for node_idx in ckpt_region:
+                for node in node_dict[node_idx]:
                     setattr(node, "activation_checkpoint", ckpt_idx)
             ckpt_idx += 1
             ckpt_region = []
         elif isinstance(op, ForwardCheck):
-            for idx in ckpt_region:
-                for node in node_dict[idx]:
+            for node_idx in ckpt_region:
+                for node in node_dict[node_idx]:
                     setattr(node, "activation_checkpoint", ckpt_idx)
             ckpt_idx += 1
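
The `idx` -> `node_idx` rename above appears to exist so the inner loops no longer re-bind the `idx` that the enclosing loop over the sequence still uses (the unchanged `ckpt_region.append(idx)` in the next hunk reads it). A minimal sketch of that shadowing pattern, with illustrative names that are not from the patch:

    regions = [[0, 1], [2]]
    ckpt_region = []
    for idx, region in enumerate(regions):
        for idx in region:           # re-binds the outer `idx`
            pass
        ckpt_region.append(idx)      # appends the last node index, not the op index
    print(ckpt_region)               # [1, 2] rather than [0, 1]
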
@@ -185,7 +186,19 @@ def _annotate_from_sequence(sequence: Sequence, node_dict: Dict[int, Node]) -> GraphModule
             ckpt_region.append(idx)
 
 
-def solver_rotor(gm: GraphModule, data: torch.Tensor, mem_limit: int, mem_slots: int = 500) -> GraphModule:
+def solver_rotor(gm: ColoGraphModule, data: torch.Tensor, mem_limit: int, mem_slots: int = 500) -> ColoGraphModule:
+    """Solver that automatically finds activation checkpoints in rotor's manner.
+
+    Args:
+        gm (ColoGraphModule): ColoGraphModule generated by tracing the model.
+        data (torch.Tensor): input data.
+        mem_limit (int): memory budget in bytes.
+        mem_slots (int, optional): number of slots for discretizing the memory budget. Defaults to 500.
+
+    Returns:
+        ColoGraphModule: annotated ColoGraphModule with the __sequence__ attribute set.
+    """
     node_dict = linearize(gm)
     mem_unit = mem_limit // mem_slots
     MetaInfoProp(gm).run(data)
@@ -193,4 +206,7 @@ def solver_rotor(gm: GraphModule, data: torch.Tensor, mem_limit: int, mem_slots: int = 500) -> GraphModule
     opt_table = _compute_table(chain, mem_slots)
     sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table)
     _annotate_from_sequence(sequence, node_dict)
+
+    # set the __sequence__ attribute on the ColoGraphModule
+    setattr(gm, "__sequence__", sequence)
     return gm
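
For reference, a minimal usage sketch of the new `solver_rotor` signature. The tracer call and import paths are assumptions about how a ColoGraphModule is typically built (they are not part of this patch), and the model and 512 MiB budget are illustrative:

    import torch
    import torchvision.models as models
    from colossalai.fx import ColoTracer
    from colossalai.fx.graph_module import ColoGraphModule

    model = models.resnet18()
    data = torch.rand(4, 3, 224, 224)

    # Trace into a ColoGraphModule (tracer call assumed, not from this patch).
    graph = ColoTracer().trace(model, meta_args={'x': data.to('meta')})
    gm = ColoGraphModule(model, graph)

    # solver_rotor annotates nodes with `activation_checkpoint` indices and
    # stores the chosen schedule on `gm.__sequence__`.
    gm = solver_rotor(gm, data=data, mem_limit=512 * 1024**2)
    print(gm.__sequence__)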