mirror of https://github.com/InternLM/InternLM

fix(internlm/utils/parallel.py): fix circular import

parent 610e011133
commit 1b71b19e23

@@ -328,6 +328,9 @@ class ParallelContext(metaclass=SingletonMeta):
                 return False
         return self.is_last_rank(ParallelMode.PIPELINE)
 
+    def is_no_pp_or_last_stage(self):
+        return not self.is_initialized(ParallelMode.PIPELINE) or self.is_pipeline_last_stage()
+
     def get_world_size(self, parallel_mode: ParallelMode):
         """Returns the world size for `parallel_mode`.

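The hunk above moves `is_no_pp_or_last_stage` onto the `ParallelContext` singleton, so call sites reach it through the already-imported `gpc` object instead of importing it from `internlm.utils.parallel`, which breaks the import cycle named in the commit title. A minimal before/after sketch of a call site (illustrative only, not a line from this commit):

from internlm.core.context import global_context as gpc

# before: from internlm.utils.parallel import is_no_pp_or_last_stage
#         if is_no_pp_or_last_stage():
# after:  the helper is a method of the context singleton, so no extra
#         import of internlm.utils.parallel is needed.
if gpc.is_no_pp_or_last_stage():
    pass  # e.g. compute metrics that only exist where the model output lives
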
@@ -16,7 +16,7 @@ from internlm.utils.common import get_master_node
 from internlm.utils.logger import get_logger
 from internlm.utils.timeout import llm_timeout
 
-# check pacakge
+# check package
 try:
     import numa
     from numa import memory, schedule

@@ -78,10 +78,16 @@ def args_sanity_check():
     else:
         pp = gpc.config.parallel.pipeline.size
 
+    # check fsdp config
     if "use_fsdp" not in gpc.config.parallel:
         gpc.config.parallel._add_item("use_fsdp", False)
-    assert not (gpc.config.parallel.use_fsdp and pp > 1), "FSDP not support when pipeline size > 1, please set pipeline size to 1 or close FSDP"
+    assert not (
+        gpc.config.parallel.use_fsdp and pp > 1
+    ), "FSDP not support when pipeline size > 1, please set pipeline size to 1 or close FSDP"
+    if gpc.config.parallel.use_fsdp:
+        assert (
+            torch.__version__ >= "2.0.1"
+        ), f"requires torch>=2.0.1 when using fsdp but current version is {torch.__version__}"
 
     # processing the data config in gpc
     data = gpc.config.data

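For reference, a config fragment that satisfies the new sanity checks could look like the sketch below; only `pipeline.size` and `use_fsdp` come from this hunk, the surrounding dict layout is an assumption about the config file format:

# sketch of a parallel config accepted by args_sanity_check (layout assumed)
parallel = dict(
    pipeline=dict(size=1),  # FSDP is rejected whenever pipeline size > 1
    use_fsdp=True,          # additionally requires torch >= 2.0.1
)
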
@@ -6,7 +6,6 @@ from torch_scatter import scatter
 
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.utils.parallel import is_no_pp_or_last_stage
 
 
 class AccPerplex:

@@ -138,7 +137,7 @@ class AccPerplex:
         self.total_log_probs += total_log_probs
 
     def get_metric(self, reset=True):
-        if is_no_pp_or_last_stage() and self.dp_pg is not None:
+        if gpc.is_no_pp_or_last_stage() and self.dp_pg is not None:
             torch.distributed.all_reduce(self.right, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
             torch.distributed.all_reduce(self.total, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
             torch.distributed.all_reduce(self.total_log_probs, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)

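The switch to `gpc.is_no_pp_or_last_stage()` keeps the same aggregation pattern: only ranks that hold the model output join the data-parallel all-reduce. A self-contained sketch of that pattern (function and argument names are illustrative, not InternLM API):

import torch
import torch.distributed as dist

def reduce_metric_counters(counters: list[torch.Tensor], dp_group, holds_output: bool) -> None:
    # `holds_output` plays the role of gpc.is_no_pp_or_last_stage() above.
    if holds_output and dp_group is not None:
        for counter in counters:
            dist.all_reduce(counter, op=dist.ReduceOp.SUM, group=dp_group)
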
@@ -236,7 +235,7 @@ class LossWithTypeId:
         self.ds_token_num += token_num_type
 
     def get_metric(self, reset=True):
-        if is_no_pp_or_last_stage() and self.dp_pg is not None:
+        if gpc.is_no_pp_or_last_stage() and self.dp_pg is not None:
             torch.distributed.all_reduce(self.loss, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
             torch.distributed.all_reduce(self.token_num, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
             if hasattr(self, "total_type_count"):

@@ -48,11 +48,7 @@ from internlm.train.utils import create_param_groups
 from internlm.utils.common import DummyProfile
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
-from internlm.utils.parallel import (
-    is_no_pp_or_last_stage,
-    sync_model_param,
-    sync_model_param_within_tp,
-)
+from internlm.utils.parallel import sync_model_param, sync_model_param_within_tp
 from internlm.utils.registry import MODEL_INITIALIZER
 from internlm.utils.timeout import llm_timeout
 

@@ -85,7 +81,7 @@ def initialize_model():
     else:
         model = NaiveAMPModel(
             model=model,
-            output_to_fp32=is_no_pp_or_last_stage(),
+            output_to_fp32=gpc.is_no_pp_or_last_stage(),
             dtype=gpc.config.model.get("dtype", torch.half),
             sync_buffer=False,
         )

@@ -113,12 +109,13 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
     RMSNorm = try_import_RMSNorm()
     if gpc.config.parallel.use_fsdp:
         # pre-save info for tensor parallel
-        tp_dict = dict()
-        for name, param in model.named_parameters():
-            if hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL):
-                tp_dict.update({name.replace("model.", ""): True})
-            else:
-                tp_dict.update({name.replace("model.", ""): False})
+        if gpc.get_world_size(ParallelMode.TENSOR) > 1:
+            tp_dict = dict()
+            for name, param in model.named_parameters():
+                if hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL):
+                    tp_dict.update({name.replace("model.", ""): True})
+                else:
+                    tp_dict.update({name.replace("model.", ""): False})
 
         # set wrap_policy for fsdp wrap
         transformer_wrap_policy = functools.partial(

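The `tp_dict` bookkeeping above records, per parameter name (with the `model.` prefix stripped), whether the parameter was marked tensor-parallel before FSDP wrapping, so the flag can be restored afterwards. A tiny standalone illustration (the string value of `IS_TENSOR_PARALLEL` is an assumption):

import torch.nn as nn

IS_TENSOR_PARALLEL = "is_tensor_parallel"  # assumed value of the constant

model = nn.Sequential(nn.Linear(8, 8))
setattr(model[0].weight, IS_TENSOR_PARALLEL, True)  # pretend this weight is TP-sharded

tp_dict = dict()
for name, param in model.named_parameters():
    is_tp = hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL)
    tp_dict.update({name.replace("model.", ""): bool(is_tp)})

print(tp_dict)  # {'0.weight': True, '0.bias': False}
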
@@ -136,13 +133,14 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
             forward_prefetch=True,
             backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
             limit_all_gathers=True,
-            use_orig_params=True,
+            use_orig_params=False,
         )
 
-        # re-set attribute for fsdp module
-        for (name, param), pre in zip(model.named_parameters(), tp_dict):
-            if pre in name and tp_dict[pre]:
-                setattr(param, IS_TENSOR_PARALLEL, True)
+        # re-set attribute for fsdp module with tensor parallel
+        if gpc.get_world_size(ParallelMode.TENSOR) > 1:
+            for (name, param), pre in zip(model.named_parameters(), tp_dict):
+                if pre in name and tp_dict[pre]:
+                    setattr(param, IS_TENSOR_PARALLEL, True)
 
     return model
 

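Since the hunk only shows the tail of the wrap call, here is a self-contained sketch of how such an FSDP wrap is typically assembled with the public torch.distributed.fsdp API; the helper name and the transformer block class are assumptions, only the keyword arguments shown in the diff are taken from it:

import functools

import torch.nn as nn
from torch.distributed.fsdp import BackwardPrefetch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

def wrap_with_fsdp(model: nn.Module, transformer_block_cls: type) -> nn.Module:
    # Requires an initialized torch.distributed process group.
    transformer_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={transformer_block_cls},
    )
    return FSDP(
        model,
        auto_wrap_policy=transformer_wrap_policy,
        forward_prefetch=True,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        limit_all_gathers=True,
        use_orig_params=False,  # this commit flips the flag from True to False
    )
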
@@ -421,7 +419,7 @@ def record_current_batch_training_metrics(
     timer.store_last_timers()
     if success_update in (0, True):
         train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
-        if is_no_pp_or_last_stage():
+        if gpc.is_no_pp_or_last_stage():
             acc_perplex = metric.get_metric()
 
     if success_update and gpc.is_rank_for_log():

@@ -51,10 +51,6 @@ def sync_model_param_within_tp(model):
             dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode))
 
 
-def is_no_pp_or_last_stage():
-    return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE)
-
-
 def get_parallel_log_file_name():
     if gpc.is_rank_for_log():
         fn_prefix = "main_"  # Indicates a rank with more output information

train.py

@@ -307,9 +307,6 @@ if __name__ == "__main__":
     # initialize distributed environment
     initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
     assert hasattr(gpc, "config") and gpc.config is not None
-    if gpc.config.parallel.use_fsdp:
-        assert torch.__version__ >= "2.0.1", f"requires torch>=2.0.1 but current version is {torch.__version__}"
 
     # initialize monitor manager context
     with initialize_monitor_manager(