From f867365aba565d54d792d5734159f726c3be05ac Mon Sep 17 00:00:00 2001
From: Jie Zhu
Date: Wed, 2 Mar 2022 14:25:52 +0800
Subject: [PATCH] bug fix: pass hook_list to engine (#273)

* bug fix: pass hook_list to engine

* change parameter name
---
 colossalai/initialize.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index f68947f4c..d2620c466 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -27,6 +27,7 @@
 from colossalai.utils import (accumulate_gradient, get_current_device,
                               is_using_ddp, is_using_pp, is_using_sequence, sync_model_param)
 from colossalai.zero import convert_to_zero, ShardedOptimizer
+from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
 
 
 def get_default_parser():
@@ -228,6 +229,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
                train_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
                test_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
                lr_scheduler: _LRScheduler = None,
+               ophooks: List[BaseOpHook] = [],
                verbose: bool = True
                ) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
     """Core function to wrap the essential training components with our functionality based on the config which is
@@ -412,7 +414,8 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
                     optimizer=optimizer,
                     criterion=criterion,
                     gradient_handlers=gradient_handlers,
-                    clip_grad_norm=clip_grad_norm
+                    clip_grad_norm=clip_grad_norm,
+                    ophook_list=ophooks
                 )
 
     return engine, train_dataloader, test_dataloader, lr_scheduler
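
For reference, below is a minimal usage sketch (not part of the patch) of how a caller could pass op hooks through the new ophooks keyword, which initialize now forwards to the Engine as ophook_list. The BaseOpHook method names (pre_fwd_exec, post_fwd_exec, pre_bwd_exec, post_bwd_exec, post_iter), the launch_from_torch call, and the PrintForwardHook class are assumptions about the ColossalAI API of this period, not part of the change itself:

    import torch
    import colossalai
    from colossalai.engine.ophooks import BaseOpHook

    class PrintForwardHook(BaseOpHook):
        # Hypothetical hook: logs each submodule name before its forward pass.
        # The overridden method names are assumed from BaseOpHook's interface.
        def pre_fwd_exec(self, module: torch.nn.Module, *args):
            print(f'pre-forward: {module.__class__.__name__}')

        def post_fwd_exec(self, module: torch.nn.Module, *args):
            pass

        def pre_bwd_exec(self, module: torch.nn.Module, input, output):
            pass

        def post_bwd_exec(self, module: torch.nn.Module, input):
            pass

        def post_iter(self):
            pass

    colossalai.launch_from_torch(config={})  # distributed context must be set up first
    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    criterion = torch.nn.MSELoss()

    # The hooks passed here are handed to the Engine via its ophook_list argument.
    engine, *_ = colossalai.initialize(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        ophooks=[PrintForwardHook()],
    )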