flake8 style (#352)

pull/394/head
Liang Bowen 2022-03-09 17:34:43 +08:00 committed by Frank Lee
parent 54ee8d1254
commit 7eb87f516d
3 changed files with 6 additions and 4 deletions

View File

@ -38,7 +38,7 @@ class CheckpointModule(nn.Module):
def divide(numerator, denominator):
"""Only allow exact division
:param numerator: Numerator of the division
:param denominator: Denominator of the division
"""

View File

@ -101,7 +101,7 @@ class WrappedDropPath(nn.Module):
@LAYERS.register_module
class VanillaPatchEmbedding(nn.Module):
"""
"""
2D Image to Patch Embedding
:param img_size: image size

View File

@ -33,14 +33,16 @@ class PipelineSharedModuleWrapper:
self.ranks_in_group = sub_ranks
def register_module(self, module: nn.Module):
assert self.ranks_in_group is not None, f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
assert self.ranks_in_group is not None,\
f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
src = self.ranks_in_group[self.pipeline_ranks[0]]
for p in module.parameters():
setattr(p, 'pipeline_shared_module_pg', self.group)
dist.broadcast(p, src, group=self.group)
def register_parameter(self, param: nn.Parameter):
assert self.ranks_in_group is not None, f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
assert self.ranks_in_group is not None,\
f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
src = self.ranks_in_group[self.pipeline_ranks[0]]
setattr(param, 'pipeline_shared_module_pg', self.group)
dist.broadcast(param, src, group=self.group)