mirror of https://github.com/hpcaitech/ColossalAI
bug fix: pass hook_list to engine (#273)
* bug fix: pass hook_list to engine * change parameter namepull/394/head
parent
5a560a060a
commit
f867365aba
|
@ -27,6 +27,7 @@ from colossalai.utils import (accumulate_gradient, get_current_device,
|
||||||
is_using_ddp, is_using_pp, is_using_sequence,
|
is_using_ddp, is_using_pp, is_using_sequence,
|
||||||
sync_model_param)
|
sync_model_param)
|
||||||
from colossalai.zero import convert_to_zero, ShardedOptimizer
|
from colossalai.zero import convert_to_zero, ShardedOptimizer
|
||||||
|
from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
|
||||||
|
|
||||||
|
|
||||||
def get_default_parser():
|
def get_default_parser():
|
||||||
|
@ -228,6 +229,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
|
||||||
train_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
|
train_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
|
||||||
test_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
|
test_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
|
||||||
lr_scheduler: _LRScheduler = None,
|
lr_scheduler: _LRScheduler = None,
|
||||||
|
ophooks: List[BaseOpHook] = [],
|
||||||
verbose: bool = True
|
verbose: bool = True
|
||||||
) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
|
) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
|
||||||
"""Core function to wrap the essential training components with our functionality based on the config which is
|
"""Core function to wrap the essential training components with our functionality based on the config which is
|
||||||
|
@ -412,7 +414,8 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
criterion=criterion,
|
criterion=criterion,
|
||||||
gradient_handlers=gradient_handlers,
|
gradient_handlers=gradient_handlers,
|
||||||
clip_grad_norm=clip_grad_norm
|
clip_grad_norm=clip_grad_norm,
|
||||||
|
ophook_list=ophooks
|
||||||
)
|
)
|
||||||
|
|
||||||
return engine, train_dataloader, test_dataloader, lr_scheduler
|
return engine, train_dataloader, test_dataloader, lr_scheduler
|
||||||
|
|
Loading…
Reference in New Issue