mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish colossalai/nn/layer/utils/common.py code style (#983)
parent
bda70b4b66
commit
571f12eff3
|
@ -13,7 +13,8 @@ from torch import Tensor, nn
|
||||||
|
|
||||||
|
|
||||||
class CheckpointModule(nn.Module):
|
class CheckpointModule(nn.Module):
|
||||||
def __init__(self, checkpoint: bool = True, offload : bool = False):
|
|
||||||
|
def __init__(self, checkpoint: bool = True, offload: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.checkpoint = checkpoint
|
self.checkpoint = checkpoint
|
||||||
self._use_checkpoint = checkpoint
|
self._use_checkpoint = checkpoint
|
||||||
|
@ -78,6 +79,7 @@ def get_tensor_parallel_mode():
|
||||||
|
|
||||||
|
|
||||||
def _ntuple(n):
|
def _ntuple(n):
|
||||||
|
|
||||||
def parse(x):
|
def parse(x):
|
||||||
if isinstance(x, collections.abc.Iterable):
|
if isinstance(x, collections.abc.Iterable):
|
||||||
return x
|
return x
|
||||||
|
|
Loading…
Reference in New Issue