fix(internlm/utils/parallel.py): fix circular import

pull/293/head
huangting4201 2023-10-08 17:23:29 +08:00
parent 610e011133
commit 1b71b19e23
6 changed files with 30 additions and 31 deletions

View File

@ -328,6 +328,9 @@ class ParallelContext(metaclass=SingletonMeta):
return False return False
return self.is_last_rank(ParallelMode.PIPELINE) return self.is_last_rank(ParallelMode.PIPELINE)
def is_no_pp_or_last_stage(self):
return not self.is_initialized(ParallelMode.PIPELINE) or self.is_pipeline_last_stage()
def get_world_size(self, parallel_mode: ParallelMode): def get_world_size(self, parallel_mode: ParallelMode):
"""Returns the world size for `parallel_mode`. """Returns the world size for `parallel_mode`.

View File

@ -16,7 +16,7 @@ from internlm.utils.common import get_master_node
from internlm.utils.logger import get_logger from internlm.utils.logger import get_logger
from internlm.utils.timeout import llm_timeout from internlm.utils.timeout import llm_timeout
# check pacakge # check package
try: try:
import numa import numa
from numa import memory, schedule from numa import memory, schedule
@ -78,10 +78,16 @@ def args_sanity_check():
else: else:
pp = gpc.config.parallel.pipeline.size pp = gpc.config.parallel.pipeline.size
# check fsdp config
if "use_fsdp" not in gpc.config.parallel: if "use_fsdp" not in gpc.config.parallel:
gpc.config.parallel._add_item("use_fsdp", False) gpc.config.parallel._add_item("use_fsdp", False)
assert not (
assert not (gpc.config.parallel.use_fsdp and pp > 1), "FSDP not support when pipeline size > 1, please set pipeline size to 1 or close FSDP" gpc.config.parallel.use_fsdp and pp > 1
), "FSDP not support when pipeline size > 1, please set pipeline size to 1 or close FSDP"
if gpc.config.parallel.use_fsdp:
assert (
torch.__version__ >= "2.0.1"
), f"requires torch>=2.0.1 when using fsdp but current version is {torch.__version__}"
# processing the data config in gpc # processing the data config in gpc
data = gpc.config.data data = gpc.config.data

View File

@ -6,7 +6,6 @@ from torch_scatter import scatter
from internlm.core.context import ParallelMode from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc from internlm.core.context import global_context as gpc
from internlm.utils.parallel import is_no_pp_or_last_stage
class AccPerplex: class AccPerplex:
@ -138,7 +137,7 @@ class AccPerplex:
self.total_log_probs += total_log_probs self.total_log_probs += total_log_probs
def get_metric(self, reset=True): def get_metric(self, reset=True):
if is_no_pp_or_last_stage() and self.dp_pg is not None: if gpc.is_no_pp_or_last_stage() and self.dp_pg is not None:
torch.distributed.all_reduce(self.right, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg) torch.distributed.all_reduce(self.right, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
torch.distributed.all_reduce(self.total, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg) torch.distributed.all_reduce(self.total, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
torch.distributed.all_reduce(self.total_log_probs, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg) torch.distributed.all_reduce(self.total_log_probs, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
@ -236,7 +235,7 @@ class LossWithTypeId:
self.ds_token_num += token_num_type self.ds_token_num += token_num_type
def get_metric(self, reset=True): def get_metric(self, reset=True):
if is_no_pp_or_last_stage() and self.dp_pg is not None: if gpc.is_no_pp_or_last_stage() and self.dp_pg is not None:
torch.distributed.all_reduce(self.loss, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg) torch.distributed.all_reduce(self.loss, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
torch.distributed.all_reduce(self.token_num, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg) torch.distributed.all_reduce(self.token_num, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
if hasattr(self, "total_type_count"): if hasattr(self, "total_type_count"):

View File

@ -48,11 +48,7 @@ from internlm.train.utils import create_param_groups
from internlm.utils.common import DummyProfile from internlm.utils.common import DummyProfile
from internlm.utils.logger import get_logger from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.parallel import ( from internlm.utils.parallel import sync_model_param, sync_model_param_within_tp
is_no_pp_or_last_stage,
sync_model_param,
sync_model_param_within_tp,
)
from internlm.utils.registry import MODEL_INITIALIZER from internlm.utils.registry import MODEL_INITIALIZER
from internlm.utils.timeout import llm_timeout from internlm.utils.timeout import llm_timeout
@ -85,7 +81,7 @@ def initialize_model():
else: else:
model = NaiveAMPModel( model = NaiveAMPModel(
model=model, model=model,
output_to_fp32=is_no_pp_or_last_stage(), output_to_fp32=gpc.is_no_pp_or_last_stage(),
dtype=gpc.config.model.get("dtype", torch.half), dtype=gpc.config.model.get("dtype", torch.half),
sync_buffer=False, sync_buffer=False,
) )
@ -113,12 +109,13 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
RMSNorm = try_import_RMSNorm() RMSNorm = try_import_RMSNorm()
if gpc.config.parallel.use_fsdp: if gpc.config.parallel.use_fsdp:
# pre-save info for tensor parallel # pre-save info for tensor parallel
tp_dict = dict() if gpc.get_world_size(ParallelMode.TENSOR) > 1:
for name, param in model.named_parameters(): tp_dict = dict()
if hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL): for name, param in model.named_parameters():
tp_dict.update({name.replace("model.", ""): True}) if hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL):
else: tp_dict.update({name.replace("model.", ""): True})
tp_dict.update({name.replace("model.", ""): False}) else:
tp_dict.update({name.replace("model.", ""): False})
# set wrap_policy for fsdp wrap # set wrap_policy for fsdp wrap
transformer_wrap_policy = functools.partial( transformer_wrap_policy = functools.partial(
@ -136,13 +133,14 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
forward_prefetch=True, forward_prefetch=True,
backward_prefetch=BackwardPrefetch.BACKWARD_PRE, backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
limit_all_gathers=True, limit_all_gathers=True,
use_orig_params=True, use_orig_params=False,
) )
# re-set attribute for fsdp module # re-set attribute for fsdp module with tensor parallel
for (name, param), pre in zip(model.named_parameters(), tp_dict): if gpc.get_world_size(ParallelMode.TENSOR) > 1:
if pre in name and tp_dict[pre]: for (name, param), pre in zip(model.named_parameters(), tp_dict):
setattr(param, IS_TENSOR_PARALLEL, True) if pre in name and tp_dict[pre]:
setattr(param, IS_TENSOR_PARALLEL, True)
return model return model
@ -421,7 +419,7 @@ def record_current_batch_training_metrics(
timer.store_last_timers() timer.store_last_timers()
if success_update in (0, True): if success_update in (0, True):
train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA) train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
if is_no_pp_or_last_stage(): if gpc.is_no_pp_or_last_stage():
acc_perplex = metric.get_metric() acc_perplex = metric.get_metric()
if success_update and gpc.is_rank_for_log(): if success_update and gpc.is_rank_for_log():

View File

@ -51,10 +51,6 @@ def sync_model_param_within_tp(model):
dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode)) dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode))
def is_no_pp_or_last_stage():
return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE)
def get_parallel_log_file_name(): def get_parallel_log_file_name():
if gpc.is_rank_for_log(): if gpc.is_rank_for_log():
fn_prefix = "main_" # Indicates a rank with more output information fn_prefix = "main_" # Indicates a rank with more output information

View File

@ -307,9 +307,6 @@ if __name__ == "__main__":
# initialize distributed environment # initialize distributed environment
initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed) initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
assert hasattr(gpc, "config") and gpc.config is not None assert hasattr(gpc, "config") and gpc.config is not None
if gpc.config.parallel.use_fsdp:
assert torch.__version__ >= "2.0.1", f"requires torch>=2.0.1 but current version is {torch.__version__}"
# initialize monitor manager context # initialize monitor manager context
with initialize_monitor_manager( with initialize_monitor_manager(