mirror of https://github.com/hpcaitech/ColossalAI
[fx] Fix activation codegen dealing with checkpointing first op (#1510)
parent
ac3a453a50
commit
4acc58ee20
|
@ -165,9 +165,12 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
|
||||||
|
|
||||||
# we need to check if the checkpoint need use_reentrant=False
|
# we need to check if the checkpoint need use_reentrant=False
|
||||||
use_reentrant = True
|
use_reentrant = True
|
||||||
|
non_leaf_input = 0
|
||||||
for var in input_vars[label]:
|
for var in input_vars[label]:
|
||||||
input_node = [item for item in node_list if item.name == var]
|
input_node = [item for item in node_list if item.name == var]
|
||||||
input_node = input_node[0]
|
input_node = input_node[0]
|
||||||
|
if input_node.op != "placeholder":
|
||||||
|
non_leaf_input = 1
|
||||||
for user in input_node.users:
|
for user in input_node.users:
|
||||||
if hasattr(user, "activation_checkpoint"):
|
if hasattr(user, "activation_checkpoint"):
|
||||||
if user.activation_checkpoint == label:
|
if user.activation_checkpoint == label:
|
||||||
|
@ -179,6 +182,10 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
|
||||||
if "inplace" in user.kwargs:
|
if "inplace" in user.kwargs:
|
||||||
use_reentrant = not user.kwargs["inplace"]
|
use_reentrant = not user.kwargs["inplace"]
|
||||||
|
|
||||||
|
# if all the inputs are leaf nodes, we need to set use_reentrant = False
|
||||||
|
if not non_leaf_input:
|
||||||
|
use_reentrant = False
|
||||||
|
|
||||||
# generate checkpoint function call in a new line
|
# generate checkpoint function call in a new line
|
||||||
usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label], use_reentrant)
|
usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label], use_reentrant)
|
||||||
usage += '\n'
|
usage += '\n'
|
||||||
|
|
|
@ -49,16 +49,20 @@ class MyModule(torch.nn.Module):
|
||||||
self.relu = relu()
|
self.relu = relu()
|
||||||
self.linear2 = torch.nn.Linear(4, 4)
|
self.linear2 = torch.nn.Linear(4, 4)
|
||||||
|
|
||||||
def forward(self, x):
|
def ckpt2(self, x):
|
||||||
|
return F.relu(x, inplace=True)
|
||||||
|
|
||||||
|
def ckpt3(self, x, y):
|
||||||
|
return self.linear2(x) + self.linear2(y)
|
||||||
|
|
||||||
|
def forward(self, x, y):
|
||||||
y1, y2 = checkpoint(self.mlp1, x)
|
y1, y2 = checkpoint(self.mlp1, x)
|
||||||
y3 = checkpoint(self.relu, x)
|
y3 = checkpoint(self.relu, x)
|
||||||
|
|
||||||
def ckpt2(x):
|
y4 = checkpoint(self.ckpt2, y)
|
||||||
return F.relu(x, inplace=True)
|
y5 = checkpoint(self.ckpt3, y, y4)
|
||||||
|
y6 = self.linear2(y4)
|
||||||
y4 = checkpoint(ckpt2, x)
|
return y1 + y2 + y3 + y4 + y5 + y6
|
||||||
y4 = self.linear2(y4)
|
|
||||||
return y1 + y2 + y3 + y4
|
|
||||||
|
|
||||||
|
|
||||||
def _run_act_ckpt_codegen(rank):
|
def _run_act_ckpt_codegen(rank):
|
||||||
|
@ -67,13 +71,15 @@ def _run_act_ckpt_codegen(rank):
|
||||||
|
|
||||||
# build model and run forward
|
# build model and run forward
|
||||||
model = MyModule()
|
model = MyModule()
|
||||||
data = torch.rand(4, 4)
|
data1 = torch.rand(4, 4)
|
||||||
|
data2 = torch.rand(4, 4)
|
||||||
|
|
||||||
# copy model to cuda
|
# copy model to cuda
|
||||||
model = model.to(device="cuda")
|
model = model.to(device="cuda")
|
||||||
data = data.to(device="cuda")
|
data1 = data1.to(device="cuda")
|
||||||
|
data2 = data2.to(device="cuda")
|
||||||
|
|
||||||
non_fx_out = model(data)
|
non_fx_out = model(data1, data2)
|
||||||
|
|
||||||
# trace the module and replace codegen
|
# trace the module and replace codegen
|
||||||
tracer = ColoTracer(trace_act_ckpt=True)
|
tracer = ColoTracer(trace_act_ckpt=True)
|
||||||
|
@ -99,12 +105,13 @@ def _run_act_ckpt_codegen(rank):
|
||||||
# assert checkpoint function will be generated and
|
# assert checkpoint function will be generated and
|
||||||
# the offload option is correct
|
# the offload option is correct
|
||||||
code = graph.python_code('self').src
|
code = graph.python_code('self').src
|
||||||
assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=True)' in code and \
|
assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=False)' in code and \
|
||||||
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \
|
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \
|
||||||
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, x, use_reentrant=False)' in code
|
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, y, use_reentrant=False)' in code and \
|
||||||
|
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_3, False, y, relu, use_reentrant=True)' in code
|
||||||
|
|
||||||
# recompile and verify the outputs are consistent
|
# recompile and verify the outputs are consistent
|
||||||
fx_out = gm(data)
|
fx_out = gm(data1, data2)
|
||||||
assert torch.equal(non_fx_out, fx_out)
|
assert torch.equal(non_fx_out, fx_out)
|
||||||
|
|
||||||
gpc.destroy()
|
gpc.destroy()
|
||||||
|
@ -121,13 +128,14 @@ def _run_act_ckpt_python_code_torch11(rank):
|
||||||
|
|
||||||
# build model and run forward
|
# build model and run forward
|
||||||
model = MyModule()
|
model = MyModule()
|
||||||
data = torch.rand(4, 4)
|
data1 = torch.rand(4, 4)
|
||||||
|
data2 = torch.rand(4, 4)
|
||||||
|
|
||||||
# copy model to cuda
|
# copy model to cuda
|
||||||
model = model.to(device="cuda")
|
data1 = data1.to(device="cuda")
|
||||||
data = data.to(device="cuda")
|
data2 = data2.to(device="cuda")
|
||||||
|
|
||||||
non_fx_out = model(data)
|
non_fx_out = model(data1, data2)
|
||||||
|
|
||||||
# trace the module and replace codegen
|
# trace the module and replace codegen
|
||||||
tracer = ColoTracer(trace_act_ckpt=True)
|
tracer = ColoTracer(trace_act_ckpt=True)
|
||||||
|
@ -152,12 +160,13 @@ def _run_act_ckpt_python_code_torch11(rank):
|
||||||
# assert checkpoint function will be generated and
|
# assert checkpoint function will be generated and
|
||||||
# the offload option is correct
|
# the offload option is correct
|
||||||
code = graph.python_code('self').src
|
code = graph.python_code('self').src
|
||||||
assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=True)' in code and \
|
assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=False)' in code and \
|
||||||
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \
|
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \
|
||||||
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, x, use_reentrant=False)' in code
|
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, y, use_reentrant=False)' in code and \
|
||||||
|
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_3, False, y, relu, use_reentrant=True)' in code
|
||||||
|
|
||||||
# recompile and verify the outputs are consistent
|
# recompile and verify the outputs are consistent
|
||||||
fx_out = gm(data)
|
fx_out = gm(data1, data2)
|
||||||
assert torch.equal(non_fx_out, fx_out)
|
assert torch.equal(non_fx_out, fx_out)
|
||||||
|
|
||||||
gpc.destroy()
|
gpc.destroy()
|
||||||
|
|
Loading…
Reference in New Issue