bug fix: pass hook_list to engine (#273)

* bug fix: pass hook_list to engine

* change parameter name
pull/394/head
Jie Zhu 2022-03-02 14:25:52 +08:00 committed by Frank Lee
parent 5a560a060a
commit f867365aba
1 changed file with 4 additions and 1 deletion


@@ -27,6 +27,7 @@ from colossalai.utils import (accumulate_gradient, get_current_device,
                               is_using_ddp, is_using_pp, is_using_sequence,
                               sync_model_param)
 from colossalai.zero import convert_to_zero, ShardedOptimizer
+from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
 
 
 def get_default_parser():
@@ -228,6 +229,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
                train_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
                test_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
                lr_scheduler: _LRScheduler = None,
+               ophooks: List[BaseOpHook] = [],
                verbose: bool = True
                ) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
     """Core function to wrap the essential training components with our functionality based on the config which is
@@ -412,7 +414,8 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
                     optimizer=optimizer,
                     criterion=criterion,
                     gradient_handlers=gradient_handlers,
-                    clip_grad_norm=clip_grad_norm
+                    clip_grad_norm=clip_grad_norm,
+                    ophook_list=ophooks
                     )
     return engine, train_dataloader, test_dataloader, lr_scheduler
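
This commit lets callers hand op hooks to colossalai.initialize, which now forwards them to the Engine as ophook_list; previously the hooks never reached the Engine. Below is a minimal usage sketch, assuming the BaseOpHook interface of this release (pre_fwd_exec/post_fwd_exec/pre_bwd_exec/post_bwd_exec/post_iter); the LoggingOpHook class and the model/optimizer/criterion/train_dataloader variables are hypothetical stand-ins for illustration, not part of this commit.

import torch
import colossalai
from colossalai.engine.ophooks import BaseOpHook

class LoggingOpHook(BaseOpHook):
    # Hypothetical hook: print each module's forward entry and exit.
    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        print(f'-> {module.__class__.__name__}')

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        print(f'<- {module.__class__.__name__}')

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        pass

    def post_bwd_exec(self, module: torch.nn.Module, input):
        pass

    def post_iter(self):
        pass

# model, optimizer, criterion and train_dataloader are built as usual beforehand.
engine, train_dataloader, _, _ = colossalai.initialize(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    train_dataloader=train_dataloader,
    ophooks=[LoggingOpHook()])  # now actually reaches the Engine as ophook_list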