Mirror of https://github.com/hpcaitech/ColossalAI
from colossalai.registry import HOOKS
from colossalai.trainer import BaseHook
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode


@HOOKS.register_module
class TotalBatchsizeHook(BaseHook):
    """Logs the effective global batch size once, just before training starts."""

    def __init__(self, trainer, priority: int = 2) -> None:
        super().__init__(trainer, priority)

    def before_train(self):
        # Effective global batch size = per-rank batch size
        # x gradient accumulation steps x number of data-parallel ranks.
        total_batch_size = gpc.config.BATCH_SIZE * \
            gpc.config.engine.gradient_accumulation * gpc.get_world_size(ParallelMode.DATA)
        # Log on rank 0 only so the message is not duplicated across processes.
        self.logger.info(f'Total batch size = {total_batch_size}', ranks=[0])
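For a concrete sense of what before_train() logs, here is a minimal standalone sketch of the same arithmetic with made-up values; the three numbers are hypothetical stand-ins for gpc.config.BATCH_SIZE, gpc.config.engine.gradient_accumulation, and gpc.get_world_size(ParallelMode.DATA), and are not taken from any real config:

# Hypothetical values, standing in for the config and runtime queries above.
BATCH_SIZE = 32            # per-rank batch size per forward/backward step
gradient_accumulation = 4  # micro-steps accumulated before each optimizer update
data_parallel_size = 8     # number of data-parallel ranks (world size)

# The same product computed in TotalBatchsizeHook.before_train():
total_batch_size = BATCH_SIZE * gradient_accumulation * data_parallel_size
print(f'Total batch size = {total_batch_size}')  # -> Total batch size = 1024

Because the class is decorated with @HOOKS.register_module, it can also be looked up by name through the HOOKS registry rather than imported directly, which is presumably how config-driven trainer setups in early ColossalAI referenced hooks like this one.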