mirror of https://github.com/hpcaitech/ColossalAI
flake8 style (#352)
parent
54ee8d1254
commit
7eb87f516d
|
@ -33,14 +33,16 @@ class PipelineSharedModuleWrapper:
|
||||||
self.ranks_in_group = sub_ranks
|
self.ranks_in_group = sub_ranks
|
||||||
|
|
||||||
def register_module(self, module: nn.Module):
|
def register_module(self, module: nn.Module):
|
||||||
assert self.ranks_in_group is not None, f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
|
assert self.ranks_in_group is not None,\
|
||||||
|
f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
|
||||||
src = self.ranks_in_group[self.pipeline_ranks[0]]
|
src = self.ranks_in_group[self.pipeline_ranks[0]]
|
||||||
for p in module.parameters():
|
for p in module.parameters():
|
||||||
setattr(p, 'pipeline_shared_module_pg', self.group)
|
setattr(p, 'pipeline_shared_module_pg', self.group)
|
||||||
dist.broadcast(p, src, group=self.group)
|
dist.broadcast(p, src, group=self.group)
|
||||||
|
|
||||||
def register_parameter(self, param: nn.Parameter):
|
def register_parameter(self, param: nn.Parameter):
|
||||||
assert self.ranks_in_group is not None, f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
|
assert self.ranks_in_group is not None,\
|
||||||
|
f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
|
||||||
src = self.ranks_in_group[self.pipeline_ranks[0]]
|
src = self.ranks_in_group[self.pipeline_ranks[0]]
|
||||||
setattr(param, 'pipeline_shared_module_pg', self.group)
|
setattr(param, 'pipeline_shared_module_pg', self.group)
|
||||||
dist.broadcast(param, src, group=self.group)
|
dist.broadcast(param, src, group=self.group)
|
||||||
|
|
Loading…
Reference in New Issue