diff --git a/.gitignore b/.gitignore
index 1ed2455..44b0e77 100644
--- a/.gitignore
+++ b/.gitignore
@@ -141,4 +141,5 @@ small_demo/
 core.*
 
 # Run
-llm_ckpts
\ No newline at end of file
+llm_ckpts
+events.*
\ No newline at end of file
diff --git a/internlm/core/pipeline_scheduler.py b/internlm/core/pipeline_scheduler.py
index b6b1d80..4474395 100644
--- a/internlm/core/pipeline_scheduler.py
+++ b/internlm/core/pipeline_scheduler.py
@@ -4,6 +4,7 @@
 # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
 
 import inspect
+from contextlib import contextmanager
 from typing import Callable, List, Tuple, Union
 
 import torch.cuda
@@ -12,11 +13,7 @@ import internlm.core.communication as comm
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
-from internlm.utils.common import (
-    get_current_device,
-    move_to_device,
-    switch_virtual_pipeline_parallel_rank,
-)
+from internlm.utils.common import get_current_device, move_to_device
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 
@@ -76,6 +73,16 @@ def pack_return_tensors(return_tensors):
     return output, label
 
 
+@contextmanager
+def switch_virtual_pipeline_parallel_rank(rank):
+    prev_rank = gpc.virtual_pipeline_parallel_rank
+    try:
+        gpc.set_virtual_pipeline_parallel_rank(rank)
+        yield
+    finally:
+        gpc.set_virtual_pipeline_parallel_rank(prev_rank)
+
+
 class PipelineScheduler(BaseScheduler):
     """A helper schedule class for pipeline parallelism running environment.
     It uses non-interleaved 1F1B strategy. Other properties are similar as
@@ -185,8 +192,8 @@ class PipelineScheduler(BaseScheduler):
         # TODO: remove this after testing new zero with pipeline parallelism
         model = engine.model
         dtype = None
-        if isinstance(model, NaiveAMPModel):
-            dtype = torch.half
+        # if isinstance(model, NaiveAMPModel):
+        #     dtype = torch.half  # TODO: an operation needs to be added here so that bf16 can be supported
         types = set()
         for param in model.parameters():
diff --git a/internlm/core/trainer.py b/internlm/core/trainer.py
index ad83c94..2215751 100644
--- a/internlm/core/trainer.py
+++ b/internlm/core/trainer.py
@@ -8,7 +8,10 @@ from typing import Iterable, Optional
 
 from internlm.core.engine import Engine
 from internlm.core.no_pipeline_scheduler import BaseScheduler, NonPipelineScheduler
-from internlm.core.pipeline_scheduler import PipelineScheduler
+from internlm.core.pipeline_scheduler import (
+    InterleavedPipelineScheduler,
+    PipelineScheduler,
+)
 
 
 class TrainState:
@@ -113,9 +116,8 @@ class Trainer:
         ), f"expected schedule to be of type BaseSchedule, but got {type(schedule)}"
         self._schedule = schedule
 
-        self.uses_pipeline = isinstance(schedule, PipelineScheduler)
         if self.uses_pipeline:
-            self._schedule.pre_processing(self)
+            self._schedule.pre_processing(self._engine)
 
     @property
     def engine(self):
@@ -128,7 +130,7 @@ class Trainer:
     @property
     def uses_pipeline(self):
         """Returns whether the pipeline parallel is used or not."""
-        return self.uses_pipeline
+        return isinstance(self._schedule, (PipelineScheduler, InterleavedPipelineScheduler))
 
     def train(self):
         self._engine.train()
diff --git a/internlm/data/batch_sampler.py b/internlm/data/batch_sampler.py
index 1ee4126..16fd6fc 100644
--- a/internlm/data/batch_sampler.py
+++ b/internlm/data/batch_sampler.py
@@ -219,11 +219,6 @@ class StaticBatchSampler:
         assert (
             batch_size - self.start_bsz
         ) % self.bsz_incre == 0, f"{batch_size} - {self.start_bsz} should be multiple of {self.bsz_incre}"
-        assert (
-            self.start_bsz // micro_bsz >= 4
-        ), f"Must have more start samples:`{self.start_bsz}` with micro_bsz:\
-            `{micro_bsz}`, so that the pipeline can run correctly"
-
         assert batch_size % micro_bsz == 0, f"batch_size({batch_size}) should be multiple of micro_bsz({micro_bsz})"
         assert (
             self.start_bsz % micro_bsz == 0
diff --git a/internlm/initialize/initialize_trainer.py b/internlm/initialize/initialize_trainer.py
index c4c635e..0df33dd 100644
--- a/internlm/initialize/initialize_trainer.py
+++ b/internlm/initialize/initialize_trainer.py
@@ -80,6 +80,7 @@ def initialize_trainer(
             gradient_handlers.append(handler)
 
     # initialize scheduler for trainer
+    scheduler = None
    if is_using_pp:
         gpc.config.NUM_MICRO_BATCHES = gpc.config.data.micro_num
         tensor_shape = get_tensor_shape()
@@ -115,7 +116,7 @@ def initialize_trainer(
         )
 
     # if bf16 is used, this value will be wrongly set to fp32, so it needs to be corrected manually
-    if hasattr(gpc.config.model, "dtype"):
+    if hasattr(gpc.config.model, "dtype") and gpc.config.model.dtype == "torch.bfloat16":
         scheduler.dtype = torch.bfloat16
 
     trainer = Trainer(engine, scheduler)
diff --git a/internlm/utils/common.py b/internlm/utils/common.py
index 7edaa92..7c069ee 100644
--- a/internlm/utils/common.py
+++ b/internlm/utils/common.py
@@ -13,7 +13,6 @@ import numpy as np
 import torch
 
 import internlm
-from internlm.core.context import global_context as gpc
 
 CURRENT_TIME = None
 
@@ -173,16 +172,6 @@ def conditional_context(context_manager, enable=True):
         yield
 
 
-@contextmanager
-def switch_virtual_pipeline_parallel_rank(rank):
-    prev_rank = gpc.virtual_pipeline_parallel_rank
-    try:
-        gpc.set_virtual_pipeline_parallel_rank(rank)
-        yield
-    finally:
-        gpc.set_virtual_pipeline_parallel_rank(prev_rank)
-
-
 class BatchSkipper:
     """
     BatchSkipper is used to determine whether to skip the current batch_idx.