diff --git a/colossalai/fx/codegen/activation_checkpoint_codegen.py b/colossalai/fx/codegen/activation_checkpoint_codegen.py index 6c1c2ae82..1bbb1418f 100644 --- a/colossalai/fx/codegen/activation_checkpoint_codegen.py +++ b/colossalai/fx/codegen/activation_checkpoint_codegen.py @@ -373,10 +373,10 @@ if codegen_available: code = ''.join(body) code = '\n'.join(' ' + line for line in code.split('\n')) fn_code = f""" - {wrap_stmts} +{wrap_stmts} - {prologue} - {code}""" +{prologue} +{code}""" return PythonCode(fn_code, globals_) else: