bug fix: pass hook_list to engine (#273)

* bug fix: pass hook_list to engine

* change parameter name
pull/394/head
Jie Zhu 3 years ago committed by Frank Lee
parent 5a560a060a
commit f867365aba

@ -27,6 +27,7 @@ from colossalai.utils import (accumulate_gradient, get_current_device,
is_using_ddp, is_using_pp, is_using_sequence,
sync_model_param)
from colossalai.zero import convert_to_zero, ShardedOptimizer
from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
def get_default_parser():
@ -228,6 +229,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
train_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
test_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
lr_scheduler: _LRScheduler = None,
ophooks: List[BaseOpHook] = [],
verbose: bool = True
) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
"""Core function to wrap the essential training components with our functionality based on the config which is
@ -412,7 +414,8 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
optimizer=optimizer,
criterion=criterion,
gradient_handlers=gradient_handlers,
clip_grad_norm=clip_grad_norm
clip_grad_norm=clip_grad_norm,
ophook_list=ophooks
)
return engine, train_dataloader, test_dataloader, lr_scheduler

Loading…
Cancel
Save