mirror of https://github.com/InternLM/InternLM
fix(core): fix demo running error
parent ad4d13740f
commit 955282dd87
@@ -141,4 +141,5 @@ small_demo/
 core.*
 
 # Run
-llm_ckpts
+llm_ckpts
+events.*
@@ -4,6 +4,7 @@
 # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
 
 import inspect
+from contextlib import contextmanager
 from typing import Callable, List, Tuple, Union
 
 import torch.cuda
@@ -12,11 +13,7 @@ import internlm.core.communication as comm
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
-from internlm.utils.common import (
-    get_current_device,
-    move_to_device,
-    switch_virtual_pipeline_parallel_rank,
-)
+from internlm.utils.common import get_current_device, move_to_device
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 
@@ -76,6 +73,16 @@ def pack_return_tensors(return_tensors):
     return output, label
 
 
+@contextmanager
+def switch_virtual_pipeline_parallel_rank(rank):
+    prev_rank = gpc.virtual_pipeline_parallel_rank
+    try:
+        gpc.set_virtual_pipeline_parallel_rank(rank)
+        yield
+    finally:
+        gpc.set_virtual_pipeline_parallel_rank(prev_rank)
+
+
 class PipelineScheduler(BaseScheduler):
     """A helper schedule class for pipeline parallelism running environment.
     It uses non-interleaved 1F1B strategy. Other properties are similar as
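The hunk above moves `switch_virtual_pipeline_parallel_rank` into the pipeline scheduler module. A minimal, self-contained sketch of the same save/set/restore pattern; `_DummyContext` is a stand-in used only so the example runs, the real object is `internlm.core.context.global_context`:

    from contextlib import contextmanager


    class _DummyContext:
        """Stand-in for gpc: holds a virtual pipeline parallel rank."""

        def __init__(self):
            self.virtual_pipeline_parallel_rank = 0

        def set_virtual_pipeline_parallel_rank(self, rank):
            self.virtual_pipeline_parallel_rank = rank


    gpc = _DummyContext()


    @contextmanager
    def switch_virtual_pipeline_parallel_rank(rank):
        # Save the current rank, switch to the requested one, restore on exit
        prev_rank = gpc.virtual_pipeline_parallel_rank
        try:
            gpc.set_virtual_pipeline_parallel_rank(rank)
            yield
        finally:
            gpc.set_virtual_pipeline_parallel_rank(prev_rank)


    with switch_virtual_pipeline_parallel_rank(1):
        assert gpc.virtual_pipeline_parallel_rank == 1  # active only inside the block
    assert gpc.virtual_pipeline_parallel_rank == 0      # restored afterwards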
@@ -185,8 +192,8 @@ class PipelineScheduler(BaseScheduler):
         # TODO: remove this after testing new zero with pipeline parallelism
         model = engine.model
-        dtype = None
-        if isinstance(model, NaiveAMPModel):
-            dtype = torch.half
+        # if isinstance(model, NaiveAMPModel):
+        #     dtype = torch.half
+        # TODO: an operation needs to be added here so that bf16 is also supported
         types = set()
         for param in model.parameters():
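The commented-out NaiveAMPModel check is replaced by probing the parameters' dtypes directly; the loop body is cut off in this hunk, but a hedged sketch of that kind of dtype probe (function and variable names here are illustrative, not the scheduler's actual code) looks like:

    import torch
    import torch.nn as nn


    def infer_param_dtype(model: nn.Module) -> torch.dtype:
        # Collect the dtypes of all parameters; a model cast to bf16 or fp16
        # should expose exactly one dtype here.
        types = set()
        for param in model.parameters():
            types.add(param.dtype)
        assert len(types) == 1, f"expected a single parameter dtype, got {types}"
        return types.pop()


    model = nn.Linear(4, 4).to(torch.bfloat16)
    assert infer_param_dtype(model) == torch.bfloat16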
@@ -8,7 +8,10 @@ from typing import Iterable, Optional
 
 from internlm.core.engine import Engine
 from internlm.core.no_pipeline_scheduler import BaseScheduler, NonPipelineScheduler
-from internlm.core.pipeline_scheduler import PipelineScheduler
+from internlm.core.pipeline_scheduler import (
+    InterleavedPipelineScheduler,
+    PipelineScheduler,
+)
 
 
 class TrainState:
@@ -113,9 +116,8 @@ class Trainer:
         ), f"expected schedule to be of type BaseSchedule, but got {type(schedule)}"
         self._schedule = schedule
 
-        self.uses_pipeline = isinstance(schedule, PipelineScheduler)
         if self.uses_pipeline:
-            self._schedule.pre_processing(self)
+            self._schedule.pre_processing(self._engine)
 
     @property
     def engine(self):
@@ -128,7 +130,7 @@ class Trainer:
     @property
     def uses_pipeline(self):
         """Returns whether the pipeline parallel is used or not."""
-        return self.uses_pipeline
+        return isinstance(self._schedule, (PipelineScheduler, InterleavedPipelineScheduler))
 
     def train(self):
         self._engine.train()
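In the old code, `uses_pipeline` was both an instance attribute assigned in `__init__` and a property returning itself, which cannot work in Python: assigning through a property that has no setter raises `AttributeError`, and `return self.uses_pipeline` inside the property would recurse. A minimal sketch of the failure mode and of the fixed pattern, using a simplified stand-in class rather than the real `Trainer` (the `isinstance` check on a list is a placeholder for the real schedule-type check):

    class BrokenTrainer:
        def __init__(self, schedule):
            self._schedule = schedule
            self.uses_pipeline = True  # AttributeError: the property has no setter

        @property
        def uses_pipeline(self):
            return self.uses_pipeline  # would recurse forever if ever reached


    class FixedTrainer:
        def __init__(self, schedule):
            self._schedule = schedule

        @property
        def uses_pipeline(self):
            # Derive the answer from the stored schedule instead of caching it
            return isinstance(self._schedule, (list, tuple))  # placeholder type check


    try:
        BrokenTrainer(schedule=[])
    except AttributeError as err:
        print("broken:", err)

    print("fixed:", FixedTrainer(schedule=[]).uses_pipeline)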
@@ -219,11 +219,6 @@ class StaticBatchSampler:
         assert (
             batch_size - self.start_bsz
         ) % self.bsz_incre == 0, f"{batch_size} - {self.start_bsz} should be multiple of {self.bsz_incre}"
-        assert (
-            self.start_bsz // micro_bsz >= 4
-        ), f"Must have more start samples:`{self.start_bsz}` with micro_bsz:\
-            `{micro_bsz}`, so that the pipeline can run correctly"
-
         assert batch_size % micro_bsz == 0, f"batch_size({batch_size}) should be multiple of micro_bsz({micro_bsz})"
         assert (
             self.start_bsz % micro_bsz == 0
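With the `start_bsz // micro_bsz >= 4` requirement dropped, a small demo configuration only has to pass the remaining divisibility checks. A worked example with values chosen purely for illustration, not taken from the repo's configs:

    batch_size, micro_bsz, start_bsz, bsz_incre = 16, 2, 4, 2

    assert (batch_size - start_bsz) % bsz_incre == 0  # 12 % 2 == 0
    assert batch_size % micro_bsz == 0                # 16 % 2 == 0
    assert start_bsz % micro_bsz == 0                 # 4 % 2 == 0
    # start_bsz // micro_bsz == 2, so the removed ">= 4" assert would have rejected this setup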
@@ -80,6 +80,7 @@ def initialize_trainer(
             gradient_handlers.append(handler)
 
     # initialize scheduler for trainer
+    scheduler = None
     if is_using_pp:
         gpc.config.NUM_MICRO_BATCHES = gpc.config.data.micro_num
         tensor_shape = get_tensor_shape()
@@ -115,7 +116,7 @@ def initialize_trainer(
         )
 
     # if bf16 is used, this value will be wrongly set to fp32, so it needs to be corrected manually
-    if hasattr(gpc.config.model, "dtype"):
+    if hasattr(gpc.config.model, "dtype") and gpc.config.model.dtype == "torch.bfloat16":
        scheduler.dtype = torch.bfloat16
 
     trainer = Trainer(engine, scheduler)
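The tightened condition only overrides the scheduler dtype when the config explicitly names bf16. A hedged sketch of that string-to-dtype decision, with a toy config object standing in for `gpc.config.model` and `torch.half` used purely as a stand-in default:

    import torch
    from types import SimpleNamespace


    def resolve_scheduler_dtype(model_cfg, default=torch.half):
        # Switch to bf16 only when the config spells it out; otherwise keep
        # whatever dtype the scheduler was constructed with.
        if hasattr(model_cfg, "dtype") and model_cfg.dtype == "torch.bfloat16":
            return torch.bfloat16
        return default


    assert resolve_scheduler_dtype(SimpleNamespace(dtype="torch.bfloat16")) == torch.bfloat16
    assert resolve_scheduler_dtype(SimpleNamespace(dtype="torch.float16")) == torch.half
    assert resolve_scheduler_dtype(SimpleNamespace()) == torch.half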
@@ -13,7 +13,6 @@ import numpy as np
 import torch
 
 import internlm
-from internlm.core.context import global_context as gpc
 
 CURRENT_TIME = None
 
@@ -173,16 +172,6 @@ def conditional_context(context_manager, enable=True):
         yield
 
 
-@contextmanager
-def switch_virtual_pipeline_parallel_rank(rank):
-    prev_rank = gpc.virtual_pipeline_parallel_rank
-    try:
-        gpc.set_virtual_pipeline_parallel_rank(rank)
-        yield
-    finally:
-        gpc.set_virtual_pipeline_parallel_rank(prev_rank)
-
-
 class BatchSkipper:
     """
     BatchSkipper is used to determine whether to skip the current batch_idx.