mirror of https://github.com/InternLM/InternLM
fix(internlm/utils/parallel.py): fix circular import
parent
610e011133
commit
1b71b19e23
|
@ -328,6 +328,9 @@ class ParallelContext(metaclass=SingletonMeta):
|
|||
return False
|
||||
return self.is_last_rank(ParallelMode.PIPELINE)
|
||||
|
||||
def is_no_pp_or_last_stage(self):
|
||||
return not self.is_initialized(ParallelMode.PIPELINE) or self.is_pipeline_last_stage()
|
||||
|
||||
def get_world_size(self, parallel_mode: ParallelMode):
|
||||
"""Returns the world size for `parallel_mode`.
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ from internlm.utils.common import get_master_node
|
|||
from internlm.utils.logger import get_logger
|
||||
from internlm.utils.timeout import llm_timeout
|
||||
|
||||
# check pacakge
|
||||
# check package
|
||||
try:
|
||||
import numa
|
||||
from numa import memory, schedule
|
||||
|
@ -78,10 +78,16 @@ def args_sanity_check():
|
|||
else:
|
||||
pp = gpc.config.parallel.pipeline.size
|
||||
|
||||
# check fsdp config
|
||||
if "use_fsdp" not in gpc.config.parallel:
|
||||
gpc.config.parallel._add_item("use_fsdp", False)
|
||||
|
||||
assert not (gpc.config.parallel.use_fsdp and pp > 1), "FSDP not support when pipeline size > 1, please set pipeline size to 1 or close FSDP"
|
||||
assert not (
|
||||
gpc.config.parallel.use_fsdp and pp > 1
|
||||
), "FSDP not support when pipeline size > 1, please set pipeline size to 1 or close FSDP"
|
||||
if gpc.config.parallel.use_fsdp:
|
||||
assert (
|
||||
torch.__version__ >= "2.0.1"
|
||||
), f"requires torch>=2.0.1 when using fsdp but current version is {torch.__version__}"
|
||||
|
||||
# processing the data config in gpc
|
||||
data = gpc.config.data
|
||||
|
|
|
@ -6,7 +6,6 @@ from torch_scatter import scatter
|
|||
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.utils.parallel import is_no_pp_or_last_stage
|
||||
|
||||
|
||||
class AccPerplex:
|
||||
|
@ -138,7 +137,7 @@ class AccPerplex:
|
|||
self.total_log_probs += total_log_probs
|
||||
|
||||
def get_metric(self, reset=True):
|
||||
if is_no_pp_or_last_stage() and self.dp_pg is not None:
|
||||
if gpc.is_no_pp_or_last_stage() and self.dp_pg is not None:
|
||||
torch.distributed.all_reduce(self.right, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
|
||||
torch.distributed.all_reduce(self.total, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
|
||||
torch.distributed.all_reduce(self.total_log_probs, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
|
||||
|
@ -236,7 +235,7 @@ class LossWithTypeId:
|
|||
self.ds_token_num += token_num_type
|
||||
|
||||
def get_metric(self, reset=True):
|
||||
if is_no_pp_or_last_stage() and self.dp_pg is not None:
|
||||
if gpc.is_no_pp_or_last_stage() and self.dp_pg is not None:
|
||||
torch.distributed.all_reduce(self.loss, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
|
||||
torch.distributed.all_reduce(self.token_num, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
|
||||
if hasattr(self, "total_type_count"):
|
||||
|
|
|
@ -48,11 +48,7 @@ from internlm.train.utils import create_param_groups
|
|||
from internlm.utils.common import DummyProfile
|
||||
from internlm.utils.logger import get_logger
|
||||
from internlm.utils.megatron_timers import megatron_timer as timer
|
||||
from internlm.utils.parallel import (
|
||||
is_no_pp_or_last_stage,
|
||||
sync_model_param,
|
||||
sync_model_param_within_tp,
|
||||
)
|
||||
from internlm.utils.parallel import sync_model_param, sync_model_param_within_tp
|
||||
from internlm.utils.registry import MODEL_INITIALIZER
|
||||
from internlm.utils.timeout import llm_timeout
|
||||
|
||||
|
@ -85,7 +81,7 @@ def initialize_model():
|
|||
else:
|
||||
model = NaiveAMPModel(
|
||||
model=model,
|
||||
output_to_fp32=is_no_pp_or_last_stage(),
|
||||
output_to_fp32=gpc.is_no_pp_or_last_stage(),
|
||||
dtype=gpc.config.model.get("dtype", torch.half),
|
||||
sync_buffer=False,
|
||||
)
|
||||
|
@ -113,12 +109,13 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
|
|||
RMSNorm = try_import_RMSNorm()
|
||||
if gpc.config.parallel.use_fsdp:
|
||||
# pre-save info for tensor parallel
|
||||
tp_dict = dict()
|
||||
for name, param in model.named_parameters():
|
||||
if hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL):
|
||||
tp_dict.update({name.replace("model.", ""): True})
|
||||
else:
|
||||
tp_dict.update({name.replace("model.", ""): False})
|
||||
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
|
||||
tp_dict = dict()
|
||||
for name, param in model.named_parameters():
|
||||
if hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL):
|
||||
tp_dict.update({name.replace("model.", ""): True})
|
||||
else:
|
||||
tp_dict.update({name.replace("model.", ""): False})
|
||||
|
||||
# set wrap_policy for fsdp wrap
|
||||
transformer_wrap_policy = functools.partial(
|
||||
|
@ -136,13 +133,14 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
|
|||
forward_prefetch=True,
|
||||
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
|
||||
limit_all_gathers=True,
|
||||
use_orig_params=True,
|
||||
use_orig_params=False,
|
||||
)
|
||||
|
||||
# re-set attribute for fsdp module
|
||||
for (name, param), pre in zip(model.named_parameters(), tp_dict):
|
||||
if pre in name and tp_dict[pre]:
|
||||
setattr(param, IS_TENSOR_PARALLEL, True)
|
||||
# re-set attribute for fsdp module with tensor parallel
|
||||
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
|
||||
for (name, param), pre in zip(model.named_parameters(), tp_dict):
|
||||
if pre in name and tp_dict[pre]:
|
||||
setattr(param, IS_TENSOR_PARALLEL, True)
|
||||
|
||||
return model
|
||||
|
||||
|
@ -421,7 +419,7 @@ def record_current_batch_training_metrics(
|
|||
timer.store_last_timers()
|
||||
if success_update in (0, True):
|
||||
train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
|
||||
if is_no_pp_or_last_stage():
|
||||
if gpc.is_no_pp_or_last_stage():
|
||||
acc_perplex = metric.get_metric()
|
||||
|
||||
if success_update and gpc.is_rank_for_log():
|
||||
|
|
|
@ -51,10 +51,6 @@ def sync_model_param_within_tp(model):
|
|||
dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode))
|
||||
|
||||
|
||||
def is_no_pp_or_last_stage():
|
||||
return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE)
|
||||
|
||||
|
||||
def get_parallel_log_file_name():
|
||||
if gpc.is_rank_for_log():
|
||||
fn_prefix = "main_" # Indicates a rank with more output information
|
||||
|
|
3
train.py
3
train.py
|
@ -307,9 +307,6 @@ if __name__ == "__main__":
|
|||
# initialize distributed environment
|
||||
initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
|
||||
assert hasattr(gpc, "config") and gpc.config is not None
|
||||
if gpc.config.parallel.use_fsdp:
|
||||
assert torch.__version__ >= "2.0.1", f"requires torch>=2.0.1 but current version is {torch.__version__}"
|
||||
|
||||
|
||||
# initialize monitor manager context
|
||||
with initialize_monitor_manager(
|
||||
|
|
Loading…
Reference in New Issue