[fp8] fix linear hook (#6046)

pull/6045/head
Hongxin Liu 2024-09-03 16:37:16 +08:00 committed by GitHub
parent c3b5caff0e
commit 26e553937b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 2 additions and 1 deletions

View File

@ -119,7 +119,8 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
if use_fp8:
self.op_hooks.append(FP8Hook())
if overlap_allgather:
self.op_hook = ZeroOpHook()
self.op_hooks.append(ZeroOpHook())
if use_fp8 or overlap_allgather:
for p in module.parameters():
if p.requires_grad and type(p) is not ColoParameter:
p.__class__ = ColoParameter