[fx] Fix wrong index in annotation and minimal flops in ckpt solver (#1521)

* [fx] fix wrong variable name in solver rotor

* [fx] fix wrong variable name in solver rotor

* [fx] fix the discretize bug

* [fx] fix the first op in activation checkpoint codegen

* [fx] fix some bugs of ckpt solver

* [fx] modify test_ckpt_torchvision

* [fx] set sequence to __sequence__ attr of GraphModule

* [fx] docstring modification

* [fx] remove performance test
pull/1530/head
Boyuan Yao 2 years ago committed by GitHub
parent 07f5a4e054
commit b231430bcb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,6 +1,7 @@
from typing import List, Set, Tuple, Dict
import torch
from torch.fx import GraphModule, Node
from colossalai.fx.graph_module import ColoGraphModule
import math
from .linearize import linearize
from .utils import *
@ -131,10 +132,10 @@ def _construct_chain(node_dict: Dict[int, Node], data: torch.Tensor, mem_unit: i
x_sizes.append(node_dict[key][-1].meta['tensor_meta'].numel *
torch.tensor([], dtype=node_dict[key][-1].meta['tensor_meta'].dtype).element_size())
for node in node_dict[key]:
fwd_time[-1] += node.__flops__
fwd_time[-1] += max(node.__flops__, 1)
# currently we haven't patched the backward flops count
bwd_time[-1] += node.__flops__ * 2
bwd_time[-1] += max(node.__flops__ * 2, 2)
xbar_sizes[-1] += node.__activation__
@ -164,16 +165,16 @@ def _annotate_from_sequence(sequence: Sequence, node_dict: Dict[int, Node]) -> G
elif isinstance(op, ForwardEnable):
in_ckpt = False
for idx in ckpt_region:
for node in node_dict[idx]:
for node_idx in ckpt_region:
for node in node_dict[node_idx]:
setattr(node, "activation_checkpoint", ckpt_idx)
ckpt_idx += 1
ckpt_region = []
elif isinstance(op, ForwardCheck):
for idx in ckpt_region:
for node in node_dict[idx]:
for node_idx in ckpt_region:
for node in node_dict[node_idx]:
setattr(node, "activation_checkpoint", ckpt_idx)
ckpt_idx += 1
@ -185,7 +186,19 @@ def _annotate_from_sequence(sequence: Sequence, node_dict: Dict[int, Node]) -> G
ckpt_region.append(idx)
def solver_rotor(gm: GraphModule, data: torch.Tensor, mem_limit: int, mem_slots: int = 500) -> GraphModule:
def solver_rotor(gm: ColoGraphModule, data: torch.Tensor, mem_limit: int, mem_slots: int = 500) -> ColoGraphModule:
"""solver that automatically find activation checkpoint in rotor's manner
Args:
gm (ColoGraphModule): ColoGraphModule generated by tracing model.
data (torch.Tensor): input data.
mem_limit (int): memory budget in Byte.
mem_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500.
Returns:
ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute
"""
node_dict = linearize(gm)
mem_unit = mem_limit // mem_slots
MetaInfoProp(gm).run(data)
@ -193,4 +206,7 @@ def solver_rotor(gm: GraphModule, data: torch.Tensor, mem_limit: int, mem_slots:
opt_table = _compute_table(chain, mem_slots)
sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table)
_annotate_from_sequence(sequence, node_dict)
# set __sequence__ attribute to GraphModule
setattr(gm, "__sequence__", sequence)
return gm

@ -62,13 +62,13 @@ def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Call
def _run_ckpt_solver(rank):
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
MODEL_LIST = [tm.resnet18, tm.densenet121]
MODEL_LIST = [tm.densenet121]
torch.backends.cudnn.deterministic = True
tracer = ColoTracer(trace_act_ckpt=False)
data = torch.rand(2, 3, 32, 32, device='meta')
data = torch.rand(8, 3, 224, 224, device='meta')
for solver in SOLVERS:
for model_cls in MODEL_LIST:
m = model_cls(num_classes=5)
@ -95,13 +95,13 @@ def test_ckpt_solver():
def _run_ckpt_solver_torch11(rank):
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
MODEL_LIST = [tm.resnet18, tm.densenet121]
MODEL_LIST = [tm.densenet121]
torch.backends.cudnn.deterministic = True
tracer = ColoTracer(trace_act_ckpt=False)
data = torch.rand(2, 3, 32, 32, device='meta')
data = torch.rand(8, 3, 32, 32, device='meta')
for solver in SOLVERS:
for model_cls in MODEL_LIST:
m = model_cls(num_classes=5)

Loading…
Cancel
Save