flake8 style (#352)

pull/394/head
Liang Bowen 2022-03-09 17:34:43 +08:00 committed by Frank Lee
parent 54ee8d1254
commit 7eb87f516d
3 changed files with 6 additions and 4 deletions

View File

@ -33,14 +33,16 @@ class PipelineSharedModuleWrapper:
self.ranks_in_group = sub_ranks self.ranks_in_group = sub_ranks
def register_module(self, module: nn.Module): def register_module(self, module: nn.Module):
assert self.ranks_in_group is not None, f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}' assert self.ranks_in_group is not None,\
f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
src = self.ranks_in_group[self.pipeline_ranks[0]] src = self.ranks_in_group[self.pipeline_ranks[0]]
for p in module.parameters(): for p in module.parameters():
setattr(p, 'pipeline_shared_module_pg', self.group) setattr(p, 'pipeline_shared_module_pg', self.group)
dist.broadcast(p, src, group=self.group) dist.broadcast(p, src, group=self.group)
def register_parameter(self, param: nn.Parameter): def register_parameter(self, param: nn.Parameter):
assert self.ranks_in_group is not None, f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}' assert self.ranks_in_group is not None,\
f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
src = self.ranks_in_group[self.pipeline_ranks[0]] src = self.ranks_in_group[self.pipeline_ranks[0]]
setattr(param, 'pipeline_shared_module_pg', self.group) setattr(param, 'pipeline_shared_module_pg', self.group)
dist.broadcast(param, src, group=self.group) dist.broadcast(param, src, group=self.group)