Browse Source

flake8 style (#352)

pull/394/head
Liang Bowen 3 years ago committed by Frank Lee
parent
commit
7eb87f516d
  1. 2
      colossalai/nn/layer/utils/common.py
  2. 2
      colossalai/nn/layer/vanilla/layers.py
  3. 6
      colossalai/nn/layer/wrapper/pipeline_wrapper.py

2
colossalai/nn/layer/utils/common.py

@ -38,7 +38,7 @@ class CheckpointModule(nn.Module):
def divide(numerator, denominator): def divide(numerator, denominator):
"""Only allow exact division """Only allow exact division
:param numerator: Numerator of the division :param numerator: Numerator of the division
:param denominator: Denominator of the division :param denominator: Denominator of the division
""" """

2
colossalai/nn/layer/vanilla/layers.py

@ -101,7 +101,7 @@ class WrappedDropPath(nn.Module):
@LAYERS.register_module @LAYERS.register_module
class VanillaPatchEmbedding(nn.Module): class VanillaPatchEmbedding(nn.Module):
""" """
2D Image to Patch Embedding 2D Image to Patch Embedding
:param img_size: image size :param img_size: image size

6
colossalai/nn/layer/wrapper/pipeline_wrapper.py

@ -33,14 +33,16 @@ class PipelineSharedModuleWrapper:
self.ranks_in_group = sub_ranks self.ranks_in_group = sub_ranks
def register_module(self, module: nn.Module): def register_module(self, module: nn.Module):
assert self.ranks_in_group is not None, f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}' assert self.ranks_in_group is not None,\
f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
src = self.ranks_in_group[self.pipeline_ranks[0]] src = self.ranks_in_group[self.pipeline_ranks[0]]
for p in module.parameters(): for p in module.parameters():
setattr(p, 'pipeline_shared_module_pg', self.group) setattr(p, 'pipeline_shared_module_pg', self.group)
dist.broadcast(p, src, group=self.group) dist.broadcast(p, src, group=self.group)
def register_parameter(self, param: nn.Parameter): def register_parameter(self, param: nn.Parameter):
assert self.ranks_in_group is not None, f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}' assert self.ranks_in_group is not None,\
f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
src = self.ranks_in_group[self.pipeline_ranks[0]] src = self.ranks_in_group[self.pipeline_ranks[0]]
setattr(param, 'pipeline_shared_module_pg', self.group) setattr(param, 'pipeline_shared_module_pg', self.group)
dist.broadcast(param, src, group=self.group) dist.broadcast(param, src, group=self.group)

Loading…
Cancel
Save