diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py index 396cf7b29..f7de4987c 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py @@ -73,7 +73,7 @@ def _compute_table(chain: Chain, mmax) -> Tuple: return (opt, what) -def _rec(chain, lmin, lmax, cmem, opt_table): +def _rec(chain: Chain, lmin, lmax, cmem, opt_table): """ chain : the class describing the AC graph lmin : index of the first forward to execute lmax : upper bound index of the last forward to execute (not included) @@ -97,14 +97,14 @@ def _rec(chain, lmin, lmax, cmem, opt_table): if what[cmem][lmin][lmax][0]: sequence.insert(ForwardEnable(lmin)) - sequence.insert_sequence(_rec(chain, lmin + 1, lmax, cmem - chain.cbweigth[lmin + 1], opt_table)) + sequence.insert_sequence(_rec(chain, lmin + 1, lmax, cmem - chain.cbweight[lmin + 1], opt_table)) sequence.insert(Backward(lmin)) else: j = what[cmem][lmin][lmax][1] sequence.insert(ForwardCheck(lmin)) for k in range(lmin + 1, j): sequence.insert(ForwardNograd(k)) - sequence.insert_sequence(_rec(chain, j, lmax, cmem - chain.cweigth[j], opt_table)) + sequence.insert_sequence(_rec(chain, j, lmax, cmem - chain.cweight[j], opt_table)) sequence.insert_sequence(_rec(chain, lmin, j - 1, cmem, opt_table)) return sequence diff --git a/colossalai/fx/passes/algorithms/utils.py b/colossalai/fx/passes/algorithms/utils.py index 88efe0a0c..d26f1a2e2 100644 --- a/colossalai/fx/passes/algorithms/utils.py +++ b/colossalai/fx/passes/algorithms/utils.py @@ -44,9 +44,9 @@ class Forward(Operation): def __repr__(self): return "{n}_{i}".format(n=self.name, i=self.index) - def cost(self, chain): + def cost(self, chain: Chain): if chain is not None: - return chain.fweigth[self.index] + return chain.fweight[self.index] else: return 1 @@ -80,9 +80,9 @@ class Forwards(Operation): def __repr__(self): return "F_{i}->{j}".format(i=self.index[0], j=self.index[1]) - def cost(self, chain): + def cost(self, chain: Chain): if chain is not None: - return sum(chain.fweigth[self.index[0]:self.index[1] + 1]) + return sum(chain.fweight[self.index[0]:self.index[1] + 1]) else: return (self.index[1] - self.index[0] + 1) @@ -99,9 +99,9 @@ class Backward(Operation): def __repr__(self): return "B_{i}".format(i=self.index) - def cost(self, chain): + def cost(self, chain: Chain): if chain is not None: - return chain.bweigth[self.index] + return chain.bweight[self.index] else: return 1 @@ -126,7 +126,7 @@ class MemoryAccess(Operation): def __repr__(self): return "{n}_{i}".format(n=self.name, i=self.index) - def cost(self, chain): + def cost(self, chain: Chain): return 0