fix(internlm/utils/parallel.py): fix circular import

pull/293/head
huangting4201 2023-10-08 17:23:29 +08:00
parent 610e011133
commit 1b71b19e23
6 changed files with 30 additions and 31 deletions

View File

@ -328,6 +328,9 @@ class ParallelContext(metaclass=SingletonMeta):
return False
return self.is_last_rank(ParallelMode.PIPELINE)
def is_no_pp_or_last_stage(self):
return not self.is_initialized(ParallelMode.PIPELINE) or self.is_pipeline_last_stage()
def get_world_size(self, parallel_mode: ParallelMode):
"""Returns the world size for `parallel_mode`.

View File

@ -16,7 +16,7 @@ from internlm.utils.common import get_master_node
from internlm.utils.logger import get_logger
from internlm.utils.timeout import llm_timeout
# check pacakge
# check package
try:
import numa
from numa import memory, schedule
@ -78,10 +78,16 @@ def args_sanity_check():
else:
pp = gpc.config.parallel.pipeline.size
# check fsdp config
if "use_fsdp" not in gpc.config.parallel:
gpc.config.parallel._add_item("use_fsdp", False)
assert not (gpc.config.parallel.use_fsdp and pp > 1), "FSDP not support when pipeline size > 1, please set pipeline size to 1 or close FSDP"
assert not (
gpc.config.parallel.use_fsdp and pp > 1
), "FSDP not support when pipeline size > 1, please set pipeline size to 1 or close FSDP"
if gpc.config.parallel.use_fsdp:
assert (
torch.__version__ >= "2.0.1"
), f"requires torch>=2.0.1 when using fsdp but current version is {torch.__version__}"
# processing the data config in gpc
data = gpc.config.data

View File

@ -6,7 +6,6 @@ from torch_scatter import scatter
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.utils.parallel import is_no_pp_or_last_stage
class AccPerplex:
@ -138,7 +137,7 @@ class AccPerplex:
self.total_log_probs += total_log_probs
def get_metric(self, reset=True):
if is_no_pp_or_last_stage() and self.dp_pg is not None:
if gpc.is_no_pp_or_last_stage() and self.dp_pg is not None:
torch.distributed.all_reduce(self.right, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
torch.distributed.all_reduce(self.total, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
torch.distributed.all_reduce(self.total_log_probs, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
@ -236,7 +235,7 @@ class LossWithTypeId:
self.ds_token_num += token_num_type
def get_metric(self, reset=True):
if is_no_pp_or_last_stage() and self.dp_pg is not None:
if gpc.is_no_pp_or_last_stage() and self.dp_pg is not None:
torch.distributed.all_reduce(self.loss, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
torch.distributed.all_reduce(self.token_num, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
if hasattr(self, "total_type_count"):

View File

@ -48,11 +48,7 @@ from internlm.train.utils import create_param_groups
from internlm.utils.common import DummyProfile
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.parallel import (
is_no_pp_or_last_stage,
sync_model_param,
sync_model_param_within_tp,
)
from internlm.utils.parallel import sync_model_param, sync_model_param_within_tp
from internlm.utils.registry import MODEL_INITIALIZER
from internlm.utils.timeout import llm_timeout
@ -85,7 +81,7 @@ def initialize_model():
else:
model = NaiveAMPModel(
model=model,
output_to_fp32=is_no_pp_or_last_stage(),
output_to_fp32=gpc.is_no_pp_or_last_stage(),
dtype=gpc.config.model.get("dtype", torch.half),
sync_buffer=False,
)
@ -113,12 +109,13 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
RMSNorm = try_import_RMSNorm()
if gpc.config.parallel.use_fsdp:
# pre-save info for tensor parallel
tp_dict = dict()
for name, param in model.named_parameters():
if hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL):
tp_dict.update({name.replace("model.", ""): True})
else:
tp_dict.update({name.replace("model.", ""): False})
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
tp_dict = dict()
for name, param in model.named_parameters():
if hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL):
tp_dict.update({name.replace("model.", ""): True})
else:
tp_dict.update({name.replace("model.", ""): False})
# set wrap_policy for fsdp wrap
transformer_wrap_policy = functools.partial(
@ -136,13 +133,14 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
forward_prefetch=True,
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
limit_all_gathers=True,
use_orig_params=True,
use_orig_params=False,
)
# re-set attribute for fsdp module
for (name, param), pre in zip(model.named_parameters(), tp_dict):
if pre in name and tp_dict[pre]:
setattr(param, IS_TENSOR_PARALLEL, True)
# re-set attribute for fsdp module with tensor parallel
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
for (name, param), pre in zip(model.named_parameters(), tp_dict):
if pre in name and tp_dict[pre]:
setattr(param, IS_TENSOR_PARALLEL, True)
return model
@ -421,7 +419,7 @@ def record_current_batch_training_metrics(
timer.store_last_timers()
if success_update in (0, True):
train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
if is_no_pp_or_last_stage():
if gpc.is_no_pp_or_last_stage():
acc_perplex = metric.get_metric()
if success_update and gpc.is_rank_for_log():

View File

@ -51,10 +51,6 @@ def sync_model_param_within_tp(model):
dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode))
def is_no_pp_or_last_stage():
return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE)
def get_parallel_log_file_name():
if gpc.is_rank_for_log():
fn_prefix = "main_" # Indicates a rank with more output information

View File

@ -307,9 +307,6 @@ if __name__ == "__main__":
# initialize distributed environment
initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
assert hasattr(gpc, "config") and gpc.config is not None
if gpc.config.parallel.use_fsdp:
assert torch.__version__ >= "2.0.1", f"requires torch>=2.0.1 but current version is {torch.__version__}"
# initialize monitor manager context
with initialize_monitor_manager(