mirror of https://github.com/hpcaitech/ColossalAI
[fx] Fix wrong index in annotation and minimal flops in ckpt solver (#1521)
* [fx] fix wrong variable name in solver rotor
* [fx] fix the discretize bug
* [fx] fix the first op in activation checkpoint codegen
* [fx] fix some bugs of ckpt solver
* [fx] modify test_ckpt_torchvision
* [fx] set sequence to __sequence__ attr of GraphModule
* [fx] docstring modification
* [fx] remove performance test
parent 07f5a4e054
commit b231430bcb
@@ -1,6 +1,7 @@
 from typing import List, Set, Tuple, Dict
 import torch
 from torch.fx import GraphModule, Node
+from colossalai.fx.graph_module import ColoGraphModule
 import math
 from .linearize import linearize
 from .utils import *
@@ -131,10 +132,10 @@ def _construct_chain(node_dict: Dict[int, Node], data: torch.Tensor, mem_unit: i
         x_sizes.append(node_dict[key][-1].meta['tensor_meta'].numel *
                        torch.tensor([], dtype=node_dict[key][-1].meta['tensor_meta'].dtype).element_size())
         for node in node_dict[key]:
-            fwd_time[-1] += node.__flops__
+            fwd_time[-1] += max(node.__flops__, 1)

             # currently we haven't patched the backward flops count
-            bwd_time[-1] += node.__flops__ * 2
+            bwd_time[-1] += max(node.__flops__ * 2, 2)

             xbar_sizes[-1] += node.__activation__
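The two max() guards above are the "minimal flops" fix named in the commit title: nodes whose patched flop count is zero (views, reshapes, anything the profiler has not covered yet) would otherwise contribute zero-cost stages, presumably destabilizing the cost comparisons in rotor's dynamic program. A minimal sketch of the clamping idea, using a made-up node class rather than real torch.fx nodes:

# Hypothetical stand-in for a traced node carrying profiled metadata.
class FakeNode:
    def __init__(self, flops):
        self.__flops__ = flops

def stage_costs(nodes):
    """Accumulate forward/backward time with a floor of 1 (resp. 2) per node."""
    fwd, bwd = 0, 0
    for node in nodes:
        fwd += max(node.__flops__, 1)      # never let a node look free
        bwd += max(node.__flops__ * 2, 2)  # backward assumed ~2x forward cost
    return fwd, bwd

print(stage_costs([FakeNode(0), FakeNode(100)]))  # (101, 202), not (100, 200)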
@@ -164,16 +165,16 @@ def _annotate_from_sequence(sequence: Sequence, node_dict: Dict[int, Node]) -> G
         elif isinstance(op, ForwardEnable):
             in_ckpt = False
-            for idx in ckpt_region:
-                for node in node_dict[idx]:
+            for node_idx in ckpt_region:
+                for node in node_dict[node_idx]:
                     setattr(node, "activation_checkpoint", ckpt_idx)

             ckpt_idx += 1
             ckpt_region = []

         elif isinstance(op, ForwardCheck):
-            for idx in ckpt_region:
-                for node in node_dict[idx]:
+            for node_idx in ckpt_region:
+                for node in node_dict[node_idx]:
                     setattr(node, "activation_checkpoint", ckpt_idx)

             ckpt_idx += 1
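The idx to node_idx rename is the "wrong index" fix: the enclosing walk over the sequence also binds idx (see the context line ckpt_region.append(idx) in the next hunk), and Python for-loops share a single function scope, so the nested loop silently clobbered the outer index and later iterations recorded the wrong position. A self-contained reproduction of the failure mode:

# Python loop variables live in the enclosing scope, so a nested loop
# that reuses the name rebinds it for the rest of the outer iteration.
recorded = []
for idx in range(3):          # outer index we intend to record
    for idx in (10, 11):      # BUG: clobbers the outer idx
        pass
    recorded.append(idx)      # always appends 11

print(recorded)               # [11, 11, 11], not [0, 1, 2]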
@@ -185,7 +186,19 @@ def _annotate_from_sequence(sequence: Sequence, node_dict: Dict[int, Node]) -> G
             ckpt_region.append(idx)


-def solver_rotor(gm: GraphModule, data: torch.Tensor, mem_limit: int, mem_slots: int = 500) -> GraphModule:
+def solver_rotor(gm: ColoGraphModule, data: torch.Tensor, mem_limit: int, mem_slots: int = 500) -> ColoGraphModule:
+    """Solver that automatically finds an activation checkpoint schedule in rotor's manner.
+
+    Args:
+        gm (ColoGraphModule): ColoGraphModule generated by tracing the model.
+        data (torch.Tensor): input data.
+        mem_limit (int): memory budget in bytes.
+        mem_slots (int, optional): number of slots for discretizing the memory budget. Defaults to 500.
+
+    Returns:
+        ColoGraphModule: annotated ColoGraphModule with a __sequence__ attribute.
+    """
+
     node_dict = linearize(gm)
     mem_unit = mem_limit // mem_slots
     MetaInfoProp(gm).run(data)
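For context on mem_unit = mem_limit // mem_slots: rotor's opt table is indexed by discrete memory budgets, so the byte budget is divided into mem_slots equal units and all sizes are expressed in those units (the "discretize bug" from the commit message lives in this conversion). A rough sketch of the arithmetic with made-up numbers; the rounding direction the solver actually uses is not visible in this hunk:

mem_limit = 2 * 1024**3              # 2 GiB budget, in bytes
mem_slots = 500                      # resolution of the DP table
mem_unit = mem_limit // mem_slots    # bytes represented by one slot

def to_slots(nbytes):
    # Floor division is illustrative; see the caveat above.
    return nbytes // mem_unit

print(mem_unit)                      # 4294967 bytes per slot
print(to_slots(300 * 1024**2))       # 73 slots for a 300 MiB tensor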
@@ -193,4 +206,7 @@ def solver_rotor(gm: GraphModule, data: torch.Tensor, mem_limit: int, mem_slots:
     opt_table = _compute_table(chain, mem_slots)
     sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table)
     _annotate_from_sequence(sequence, node_dict)
+
+    # set __sequence__ attribute to GraphModule
+    setattr(gm, "__sequence__", sequence)
     return gm
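End to end, the reworked solver can be driven roughly as below. The tracer setup mirrors the test file in this commit, but treat the exact import paths, the meta_args keyword, and the input name 'x' as assumptions about this revision rather than a documented API:

import torch
import torchvision.models as tm
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.algorithms import solver_rotor  # assumed module path

model = tm.densenet121(num_classes=5)
data = torch.rand(8, 3, 224, 224, device='meta')          # shapes only, no storage

graph = ColoTracer(trace_act_ckpt=False).trace(model, meta_args={'x': data})
gm = ColoGraphModule(model, graph)

gm = solver_rotor(gm, data=data, mem_limit=2 * 1024**3)   # 2 GiB budget
print(gm.__sequence__)                                    # schedule stored by this patch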
@@ -62,13 +62,13 @@ def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Call
 def _run_ckpt_solver(rank):
     colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
-    MODEL_LIST = [tm.resnet18, tm.densenet121]
+    MODEL_LIST = [tm.densenet121]

     torch.backends.cudnn.deterministic = True

     tracer = ColoTracer(trace_act_ckpt=False)

-    data = torch.rand(2, 3, 32, 32, device='meta')
+    data = torch.rand(8, 3, 224, 224, device='meta')
     for solver in SOLVERS:
         for model_cls in MODEL_LIST:
             m = model_cls(num_classes=5)
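Note that the enlarged 8x3x224x224 batch still allocates no real memory: tensors on device='meta' carry only shape and dtype, which is all MetaInfoProp and the numel * element_size arithmetic in _construct_chain need. A quick illustration:

import torch

x = torch.rand(8, 3, 224, 224, device='meta')   # no storage is allocated
nbytes = x.numel() * x.element_size()           # same size math as the solver
print(x.shape, x.dtype, nbytes)                 # torch.Size([8, 3, 224, 224]) torch.float32 4816896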
@@ -95,13 +95,13 @@ def test_ckpt_solver():
 def _run_ckpt_solver_torch11(rank):
     colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
-    MODEL_LIST = [tm.resnet18, tm.densenet121]
+    MODEL_LIST = [tm.densenet121]

     torch.backends.cudnn.deterministic = True

     tracer = ColoTracer(trace_act_ckpt=False)

-    data = torch.rand(2, 3, 32, 32, device='meta')
+    data = torch.rand(8, 3, 32, 32, device='meta')
     for solver in SOLVERS:
         for model_cls in MODEL_LIST:
             m = model_cls(num_classes=5)