|
|
|
@ -33,14 +33,16 @@ class PipelineSharedModuleWrapper:
|
|
|
|
|
self.ranks_in_group = sub_ranks |
|
|
|
|
|
|
|
|
|
def register_module(self, module: nn.Module): |
|
|
|
|
assert self.ranks_in_group is not None, f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}' |
|
|
|
|
assert self.ranks_in_group is not None,\ |
|
|
|
|
f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}' |
|
|
|
|
src = self.ranks_in_group[self.pipeline_ranks[0]] |
|
|
|
|
for p in module.parameters(): |
|
|
|
|
setattr(p, 'pipeline_shared_module_pg', self.group) |
|
|
|
|
dist.broadcast(p, src, group=self.group) |
|
|
|
|
|
|
|
|
|
def register_parameter(self, param: nn.Parameter): |
|
|
|
|
assert self.ranks_in_group is not None, f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}' |
|
|
|
|
assert self.ranks_in_group is not None,\ |
|
|
|
|
f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}' |
|
|
|
|
src = self.ranks_in_group[self.pipeline_ranks[0]] |
|
|
|
|
setattr(param, 'pipeline_shared_module_pg', self.group) |
|
|
|
|
dist.broadcast(param, src, group=self.group) |
|
|
|
|