mirror of https://github.com/hpcaitech/ColossalAI
[fx] Fix wrong index in annotation and minimal flops in ckpt solver (#1521)
* [fx] fix wrong variable name in solver rotor * [fx] fix wrong variable name in solver rotor * [fx] fix the discretize bug * [fx] fix the first op in activation checkpoint codegen * [fx] fix some bugs of ckpt solver * [fx] modify test_ckpt_torchvision * [fx] set sequence to __sequence__ attr of GraphModule * [fx] docstring modification * [fx] remove performance testpull/1530/head
parent
07f5a4e054
commit
b231430bcb
|
@ -1,6 +1,7 @@
|
||||||
from typing import List, Set, Tuple, Dict
|
from typing import List, Set, Tuple, Dict
|
||||||
import torch
|
import torch
|
||||||
from torch.fx import GraphModule, Node
|
from torch.fx import GraphModule, Node
|
||||||
|
from colossalai.fx.graph_module import ColoGraphModule
|
||||||
import math
|
import math
|
||||||
from .linearize import linearize
|
from .linearize import linearize
|
||||||
from .utils import *
|
from .utils import *
|
||||||
|
@ -131,10 +132,10 @@ def _construct_chain(node_dict: Dict[int, Node], data: torch.Tensor, mem_unit: i
|
||||||
x_sizes.append(node_dict[key][-1].meta['tensor_meta'].numel *
|
x_sizes.append(node_dict[key][-1].meta['tensor_meta'].numel *
|
||||||
torch.tensor([], dtype=node_dict[key][-1].meta['tensor_meta'].dtype).element_size())
|
torch.tensor([], dtype=node_dict[key][-1].meta['tensor_meta'].dtype).element_size())
|
||||||
for node in node_dict[key]:
|
for node in node_dict[key]:
|
||||||
fwd_time[-1] += node.__flops__
|
fwd_time[-1] += max(node.__flops__, 1)
|
||||||
|
|
||||||
# currently we haven't patched the backward flops count
|
# currently we haven't patched the backward flops count
|
||||||
bwd_time[-1] += node.__flops__ * 2
|
bwd_time[-1] += max(node.__flops__ * 2, 2)
|
||||||
|
|
||||||
xbar_sizes[-1] += node.__activation__
|
xbar_sizes[-1] += node.__activation__
|
||||||
|
|
||||||
|
@ -164,16 +165,16 @@ def _annotate_from_sequence(sequence: Sequence, node_dict: Dict[int, Node]) -> G
|
||||||
|
|
||||||
elif isinstance(op, ForwardEnable):
|
elif isinstance(op, ForwardEnable):
|
||||||
in_ckpt = False
|
in_ckpt = False
|
||||||
for idx in ckpt_region:
|
for node_idx in ckpt_region:
|
||||||
for node in node_dict[idx]:
|
for node in node_dict[node_idx]:
|
||||||
setattr(node, "activation_checkpoint", ckpt_idx)
|
setattr(node, "activation_checkpoint", ckpt_idx)
|
||||||
|
|
||||||
ckpt_idx += 1
|
ckpt_idx += 1
|
||||||
ckpt_region = []
|
ckpt_region = []
|
||||||
|
|
||||||
elif isinstance(op, ForwardCheck):
|
elif isinstance(op, ForwardCheck):
|
||||||
for idx in ckpt_region:
|
for node_idx in ckpt_region:
|
||||||
for node in node_dict[idx]:
|
for node in node_dict[node_idx]:
|
||||||
setattr(node, "activation_checkpoint", ckpt_idx)
|
setattr(node, "activation_checkpoint", ckpt_idx)
|
||||||
|
|
||||||
ckpt_idx += 1
|
ckpt_idx += 1
|
||||||
|
@ -185,7 +186,19 @@ def _annotate_from_sequence(sequence: Sequence, node_dict: Dict[int, Node]) -> G
|
||||||
ckpt_region.append(idx)
|
ckpt_region.append(idx)
|
||||||
|
|
||||||
|
|
||||||
def solver_rotor(gm: GraphModule, data: torch.Tensor, mem_limit: int, mem_slots: int = 500) -> GraphModule:
|
def solver_rotor(gm: ColoGraphModule, data: torch.Tensor, mem_limit: int, mem_slots: int = 500) -> ColoGraphModule:
|
||||||
|
"""solver that automatically find activation checkpoint in rotor's manner
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gm (ColoGraphModule): ColoGraphModule generated by tracing model.
|
||||||
|
data (torch.Tensor): input data.
|
||||||
|
mem_limit (int): memory budget in Byte.
|
||||||
|
mem_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute
|
||||||
|
"""
|
||||||
|
|
||||||
node_dict = linearize(gm)
|
node_dict = linearize(gm)
|
||||||
mem_unit = mem_limit // mem_slots
|
mem_unit = mem_limit // mem_slots
|
||||||
MetaInfoProp(gm).run(data)
|
MetaInfoProp(gm).run(data)
|
||||||
|
@ -193,4 +206,7 @@ def solver_rotor(gm: GraphModule, data: torch.Tensor, mem_limit: int, mem_slots:
|
||||||
opt_table = _compute_table(chain, mem_slots)
|
opt_table = _compute_table(chain, mem_slots)
|
||||||
sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table)
|
sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table)
|
||||||
_annotate_from_sequence(sequence, node_dict)
|
_annotate_from_sequence(sequence, node_dict)
|
||||||
|
|
||||||
|
# set __sequence__ attribute to GraphModule
|
||||||
|
setattr(gm, "__sequence__", sequence)
|
||||||
return gm
|
return gm
|
||||||
|
|
|
@ -62,13 +62,13 @@ def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Call
|
||||||
|
|
||||||
def _run_ckpt_solver(rank):
|
def _run_ckpt_solver(rank):
|
||||||
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
|
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
|
||||||
MODEL_LIST = [tm.resnet18, tm.densenet121]
|
MODEL_LIST = [tm.densenet121]
|
||||||
|
|
||||||
torch.backends.cudnn.deterministic = True
|
torch.backends.cudnn.deterministic = True
|
||||||
|
|
||||||
tracer = ColoTracer(trace_act_ckpt=False)
|
tracer = ColoTracer(trace_act_ckpt=False)
|
||||||
|
|
||||||
data = torch.rand(2, 3, 32, 32, device='meta')
|
data = torch.rand(8, 3, 224, 224, device='meta')
|
||||||
for solver in SOLVERS:
|
for solver in SOLVERS:
|
||||||
for model_cls in MODEL_LIST:
|
for model_cls in MODEL_LIST:
|
||||||
m = model_cls(num_classes=5)
|
m = model_cls(num_classes=5)
|
||||||
|
@ -95,13 +95,13 @@ def test_ckpt_solver():
|
||||||
|
|
||||||
def _run_ckpt_solver_torch11(rank):
|
def _run_ckpt_solver_torch11(rank):
|
||||||
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
|
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
|
||||||
MODEL_LIST = [tm.resnet18, tm.densenet121]
|
MODEL_LIST = [tm.densenet121]
|
||||||
|
|
||||||
torch.backends.cudnn.deterministic = True
|
torch.backends.cudnn.deterministic = True
|
||||||
|
|
||||||
tracer = ColoTracer(trace_act_ckpt=False)
|
tracer = ColoTracer(trace_act_ckpt=False)
|
||||||
|
|
||||||
data = torch.rand(2, 3, 32, 32, device='meta')
|
data = torch.rand(8, 3, 32, 32, device='meta')
|
||||||
for solver in SOLVERS:
|
for solver in SOLVERS:
|
||||||
for model_cls in MODEL_LIST:
|
for model_cls in MODEL_LIST:
|
||||||
m = model_cls(num_classes=5)
|
m = model_cls(num_classes=5)
|
||||||
|
|
Loading…
Reference in New Issue