mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
42 lines
1.4 KiB
42 lines
1.4 KiB
#!/usr/bin/env python
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
import torch.distributed as dist
|
|
|
|
from colossalai.context.parallel_mode import ParallelMode
|
|
from colossalai.core import global_context as gpc
|
|
|
|
|
|
def print_rank_0(msg: str, logger=None):
|
|
'''Print messages and save logs(optional). This is executed only if you are the rank-0 gpu.
|
|
|
|
:param msg: A str message to output
|
|
:param logger: python logger object, defaults to None
|
|
'''
|
|
if gpc.get_global_rank() == 0:
|
|
if logger is None:
|
|
print(msg, flush=True)
|
|
else:
|
|
logger.info(msg)
|
|
# print(msg, flush=True)
|
|
|
|
|
|
def sync_model_param_in_dp(model):
|
|
'''Make sure data parameters are consistent during Data Parallel Mode
|
|
|
|
:param model: A pyTorch nn.model on whose parameters you check the consistency
|
|
'''
|
|
|
|
if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
|
|
for param in model.parameters():
|
|
ranks = gpc.get_ranks_in_group(ParallelMode.DATA)
|
|
dist.broadcast(param, src=ranks[0], group=gpc.get_group(ParallelMode.DATA))
|
|
|
|
def is_dp_rank_0():
|
|
return not gpc.is_initialized(ParallelMode.DATA) or gpc.is_first_rank(ParallelMode.DATA)
|
|
|
|
def is_tp_rank_0():
|
|
return not gpc.is_initialized(ParallelMode.TENSOR) or gpc.is_first_rank(ParallelMode.TENSOR)
|
|
|
|
def is_no_pp_or_last_stage():
|
|
return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE) |