mirror of https://github.com/hpcaitech/ColossalAI
[fx] fix wrong variable name in solver rotor (#1502)
* [fx] fix wrong variable name in solver rotor * [fx] fix wrong variable name in solver rotor * code modificationpull/1506/head^2
parent
3b6a5e2593
commit
31fffd3fc5
|
@ -73,7 +73,7 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
|
||||||
return (opt, what)
|
return (opt, what)
|
||||||
|
|
||||||
|
|
||||||
def _rec(chain, lmin, lmax, cmem, opt_table):
|
def _rec(chain: Chain, lmin, lmax, cmem, opt_table):
|
||||||
""" chain : the class describing the AC graph
|
""" chain : the class describing the AC graph
|
||||||
lmin : index of the first forward to execute
|
lmin : index of the first forward to execute
|
||||||
lmax : upper bound index of the last forward to execute (not included)
|
lmax : upper bound index of the last forward to execute (not included)
|
||||||
|
@ -97,14 +97,14 @@ def _rec(chain, lmin, lmax, cmem, opt_table):
|
||||||
|
|
||||||
if what[cmem][lmin][lmax][0]:
|
if what[cmem][lmin][lmax][0]:
|
||||||
sequence.insert(ForwardEnable(lmin))
|
sequence.insert(ForwardEnable(lmin))
|
||||||
sequence.insert_sequence(_rec(chain, lmin + 1, lmax, cmem - chain.cbweigth[lmin + 1], opt_table))
|
sequence.insert_sequence(_rec(chain, lmin + 1, lmax, cmem - chain.cbweight[lmin + 1], opt_table))
|
||||||
sequence.insert(Backward(lmin))
|
sequence.insert(Backward(lmin))
|
||||||
else:
|
else:
|
||||||
j = what[cmem][lmin][lmax][1]
|
j = what[cmem][lmin][lmax][1]
|
||||||
sequence.insert(ForwardCheck(lmin))
|
sequence.insert(ForwardCheck(lmin))
|
||||||
for k in range(lmin + 1, j):
|
for k in range(lmin + 1, j):
|
||||||
sequence.insert(ForwardNograd(k))
|
sequence.insert(ForwardNograd(k))
|
||||||
sequence.insert_sequence(_rec(chain, j, lmax, cmem - chain.cweigth[j], opt_table))
|
sequence.insert_sequence(_rec(chain, j, lmax, cmem - chain.cweight[j], opt_table))
|
||||||
sequence.insert_sequence(_rec(chain, lmin, j - 1, cmem, opt_table))
|
sequence.insert_sequence(_rec(chain, lmin, j - 1, cmem, opt_table))
|
||||||
return sequence
|
return sequence
|
||||||
|
|
||||||
|
|
|
@ -44,9 +44,9 @@ class Forward(Operation):
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "{n}_{i}".format(n=self.name, i=self.index)
|
return "{n}_{i}".format(n=self.name, i=self.index)
|
||||||
|
|
||||||
def cost(self, chain):
|
def cost(self, chain: Chain):
|
||||||
if chain is not None:
|
if chain is not None:
|
||||||
return chain.fweigth[self.index]
|
return chain.fweight[self.index]
|
||||||
else:
|
else:
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
@ -80,9 +80,9 @@ class Forwards(Operation):
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "F_{i}->{j}".format(i=self.index[0], j=self.index[1])
|
return "F_{i}->{j}".format(i=self.index[0], j=self.index[1])
|
||||||
|
|
||||||
def cost(self, chain):
|
def cost(self, chain: Chain):
|
||||||
if chain is not None:
|
if chain is not None:
|
||||||
return sum(chain.fweigth[self.index[0]:self.index[1] + 1])
|
return sum(chain.fweight[self.index[0]:self.index[1] + 1])
|
||||||
else:
|
else:
|
||||||
return (self.index[1] - self.index[0] + 1)
|
return (self.index[1] - self.index[0] + 1)
|
||||||
|
|
||||||
|
@ -99,9 +99,9 @@ class Backward(Operation):
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "B_{i}".format(i=self.index)
|
return "B_{i}".format(i=self.index)
|
||||||
|
|
||||||
def cost(self, chain):
|
def cost(self, chain: Chain):
|
||||||
if chain is not None:
|
if chain is not None:
|
||||||
return chain.bweigth[self.index]
|
return chain.bweight[self.index]
|
||||||
else:
|
else:
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
@ -126,7 +126,7 @@ class MemoryAccess(Operation):
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "{n}_{i}".format(n=self.name, i=self.index)
|
return "{n}_{i}".format(n=self.name, i=self.index)
|
||||||
|
|
||||||
def cost(self, chain):
|
def cost(self, chain: Chain):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue