mirror of https://github.com/hpcaitech/ColossalAI
flake8 style (#352)
parent
54ee8d1254
commit
7eb87f516d
|
@ -38,7 +38,7 @@ class CheckpointModule(nn.Module):
|
|||
|
||||
def divide(numerator, denominator):
|
||||
"""Only allow exact division
|
||||
|
||||
|
||||
:param numerator: Numerator of the division
|
||||
:param denominator: Denominator of the division
|
||||
"""
|
||||
|
|
|
@ -101,7 +101,7 @@ class WrappedDropPath(nn.Module):
|
|||
|
||||
@LAYERS.register_module
|
||||
class VanillaPatchEmbedding(nn.Module):
|
||||
"""
|
||||
"""
|
||||
2D Image to Patch Embedding
|
||||
|
||||
:param img_size: image size
|
||||
|
|
|
@ -33,14 +33,16 @@ class PipelineSharedModuleWrapper:
|
|||
self.ranks_in_group = sub_ranks
|
||||
|
||||
def register_module(self, module: nn.Module):
|
||||
assert self.ranks_in_group is not None, f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
|
||||
assert self.ranks_in_group is not None,\
|
||||
f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
|
||||
src = self.ranks_in_group[self.pipeline_ranks[0]]
|
||||
for p in module.parameters():
|
||||
setattr(p, 'pipeline_shared_module_pg', self.group)
|
||||
dist.broadcast(p, src, group=self.group)
|
||||
|
||||
def register_parameter(self, param: nn.Parameter):
|
||||
assert self.ranks_in_group is not None, f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
|
||||
assert self.ranks_in_group is not None,\
|
||||
f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
|
||||
src = self.ranks_in_group[self.pipeline_ranks[0]]
|
||||
setattr(param, 'pipeline_shared_module_pg', self.group)
|
||||
dist.broadcast(param, src, group=self.group)
|
||||
|
|
Loading…
Reference in New Issue