mirror of https://github.com/hpcaitech/ColossalAI
[fx] fix wrong variable name in solver rotor (#1502)
* [fx] fix wrong variable name in solver rotor * [fx] fix wrong variable name in solver rotor * code modificationpull/1506/head^2
parent
3b6a5e2593
commit
31fffd3fc5
|
@ -73,7 +73,7 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
|
|||
return (opt, what)
|
||||
|
||||
|
||||
def _rec(chain, lmin, lmax, cmem, opt_table):
|
||||
def _rec(chain: Chain, lmin, lmax, cmem, opt_table):
|
||||
""" chain : the class describing the AC graph
|
||||
lmin : index of the first forward to execute
|
||||
lmax : upper bound index of the last forward to execute (not included)
|
||||
|
@ -97,14 +97,14 @@ def _rec(chain, lmin, lmax, cmem, opt_table):
|
|||
|
||||
if what[cmem][lmin][lmax][0]:
|
||||
sequence.insert(ForwardEnable(lmin))
|
||||
sequence.insert_sequence(_rec(chain, lmin + 1, lmax, cmem - chain.cbweigth[lmin + 1], opt_table))
|
||||
sequence.insert_sequence(_rec(chain, lmin + 1, lmax, cmem - chain.cbweight[lmin + 1], opt_table))
|
||||
sequence.insert(Backward(lmin))
|
||||
else:
|
||||
j = what[cmem][lmin][lmax][1]
|
||||
sequence.insert(ForwardCheck(lmin))
|
||||
for k in range(lmin + 1, j):
|
||||
sequence.insert(ForwardNograd(k))
|
||||
sequence.insert_sequence(_rec(chain, j, lmax, cmem - chain.cweigth[j], opt_table))
|
||||
sequence.insert_sequence(_rec(chain, j, lmax, cmem - chain.cweight[j], opt_table))
|
||||
sequence.insert_sequence(_rec(chain, lmin, j - 1, cmem, opt_table))
|
||||
return sequence
|
||||
|
||||
|
|
|
@ -44,9 +44,9 @@ class Forward(Operation):
|
|||
def __repr__(self):
|
||||
return "{n}_{i}".format(n=self.name, i=self.index)
|
||||
|
||||
def cost(self, chain):
|
||||
def cost(self, chain: Chain):
|
||||
if chain is not None:
|
||||
return chain.fweigth[self.index]
|
||||
return chain.fweight[self.index]
|
||||
else:
|
||||
return 1
|
||||
|
||||
|
@ -80,9 +80,9 @@ class Forwards(Operation):
|
|||
def __repr__(self):
|
||||
return "F_{i}->{j}".format(i=self.index[0], j=self.index[1])
|
||||
|
||||
def cost(self, chain):
|
||||
def cost(self, chain: Chain):
|
||||
if chain is not None:
|
||||
return sum(chain.fweigth[self.index[0]:self.index[1] + 1])
|
||||
return sum(chain.fweight[self.index[0]:self.index[1] + 1])
|
||||
else:
|
||||
return (self.index[1] - self.index[0] + 1)
|
||||
|
||||
|
@ -99,9 +99,9 @@ class Backward(Operation):
|
|||
def __repr__(self):
|
||||
return "B_{i}".format(i=self.index)
|
||||
|
||||
def cost(self, chain):
|
||||
def cost(self, chain: Chain):
|
||||
if chain is not None:
|
||||
return chain.bweigth[self.index]
|
||||
return chain.bweight[self.index]
|
||||
else:
|
||||
return 1
|
||||
|
||||
|
@ -126,7 +126,7 @@ class MemoryAccess(Operation):
|
|||
def __repr__(self):
|
||||
return "{n}_{i}".format(n=self.name, i=self.index)
|
||||
|
||||
def cost(self, chain):
|
||||
def cost(self, chain: Chain):
|
||||
return 0
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue