# ColossalAI/examples/vit-b16/hooks.py
# (original listing: 16 lines, 589 B, Python)

from colossalai.registry import HOOKS
from colossalai.trainer import BaseHook
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
@HOOKS.register_module
class TotalBatchsizeHook(BaseHook):
    """Hook that logs the effective (global) batch size before training starts.

    The reported value is the product of three quantities:
      * ``gpc.config.BATCH_SIZE`` (presumably the per-process batch size --
        confirm against the training config),
      * ``gpc.config.engine.gradient_accumulation``,
      * the world size of the ``ParallelMode.DATA`` process group.
    """

    def __init__(self, trainer, priority: int = 2) -> None:
        # All setup is delegated to BaseHook; this hook only chooses a
        # default priority of 2.
        super().__init__(trainer, priority)

    def before_train(self):
        """Compute the total batch size and log it via ``self.logger``.

        The logger is called with ``ranks=[0]``, presumably so that only
        rank 0 emits the message.
        """
        per_step_batch = gpc.config.BATCH_SIZE
        accumulation_steps = gpc.config.engine.gradient_accumulation
        data_parallel_size = gpc.get_world_size(ParallelMode.DATA)

        total_batch_size = per_step_batch * accumulation_steps * data_parallel_size
        self.logger.info(f'Total batch size = {total_batch_size}', ranks=[0])