fix(core): fix demo running error

pull/67/head
黄婷 2023-07-13 16:44:58 +08:00
parent ad4d13740f
commit 955282dd87
6 changed files with 24 additions and 29 deletions

.gitignore

@@ -141,4 +141,5 @@ small_demo/
 core.*
 # Run
-llm_ckpts
+llm_ckpts
+events.*


@@ -4,6 +4,7 @@
 # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
 import inspect
+from contextlib import contextmanager
 from typing import Callable, List, Tuple, Union

 import torch.cuda
@@ -12,11 +13,7 @@ import internlm.core.communication as comm
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
-from internlm.utils.common import (
-    get_current_device,
-    move_to_device,
-    switch_virtual_pipeline_parallel_rank,
-)
+from internlm.utils.common import get_current_device, move_to_device
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
@@ -76,6 +73,16 @@ def pack_return_tensors(return_tensors):
     return output, label


+@contextmanager
+def switch_virtual_pipeline_parallel_rank(rank):
+    prev_rank = gpc.virtual_pipeline_parallel_rank
+    try:
+        gpc.set_virtual_pipeline_parallel_rank(rank)
+        yield
+    finally:
+        gpc.set_virtual_pipeline_parallel_rank(prev_rank)
+
+
 class PipelineScheduler(BaseScheduler):
     """A helper schedule class for pipeline parallelism running environment.
     It uses non-interleaved 1F1B strategy. Other properties are similar as
@@ -185,8 +192,8 @@ class PipelineScheduler(BaseScheduler):
         # TODO: remove this after testing new zero with pipeline parallelism
         model = engine.model
         dtype = None
-        if isinstance(model, NaiveAMPModel):
-            dtype = torch.half
+        # if isinstance(model, NaiveAMPModel):
+        #     dtype = torch.half
+        # TODO: an extra step is needed here to support bf16
         types = set()
         for param in model.parameters():

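Note: the switch_virtual_pipeline_parallel_rank helper added above is moved verbatim from internlm/utils/common.py (its definition is removed in the last file of this commit), which lets the scheduler stop importing it. It is the usual save/restore context-manager pattern; a minimal self-contained sketch of that pattern, using a stand-in object instead of InternLM's real gpc:

    from contextlib import contextmanager

    class _FakeGlobalContext:
        # stand-in for internlm's global context, for illustration only
        virtual_pipeline_parallel_rank = 0

        def set_virtual_pipeline_parallel_rank(self, rank):
            self.virtual_pipeline_parallel_rank = rank

    gpc = _FakeGlobalContext()

    @contextmanager
    def switch_virtual_pipeline_parallel_rank(rank):
        prev_rank = gpc.virtual_pipeline_parallel_rank
        try:
            gpc.set_virtual_pipeline_parallel_rank(rank)
            yield
        finally:
            # the previous rank is restored even if the body raises
            gpc.set_virtual_pipeline_parallel_rank(prev_rank)

    with switch_virtual_pipeline_parallel_rank(3):
        assert gpc.virtual_pipeline_parallel_rank == 3
    assert gpc.virtual_pipeline_parallel_rank == 0

In the last hunk, commenting out the isinstance(model, NaiveAMPModel) branch leaves dtype as None, so the scheduler presumably falls back to the parameter dtypes collected in the types set below rather than forcing torch.half onto bf16 models.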

@@ -8,7 +8,10 @@ from typing import Iterable, Optional
 from internlm.core.engine import Engine
 from internlm.core.no_pipeline_scheduler import BaseScheduler, NonPipelineScheduler
-from internlm.core.pipeline_scheduler import PipelineScheduler
+from internlm.core.pipeline_scheduler import (
+    InterleavedPipelineScheduler,
+    PipelineScheduler,
+)


 class TrainState:
@@ -113,9 +116,8 @@ class Trainer:
         ), f"expected schedule to be of type BaseSchedule, but got {type(schedule)}"
         self._schedule = schedule
-        self.uses_pipeline = isinstance(schedule, PipelineScheduler)
-        if self.uses_pipeline:
-            self._schedule.pre_processing(self)
+        self._schedule.pre_processing(self._engine)

     @property
     def engine(self):
@@ -128,7 +130,7 @@ class Trainer:
     @property
     def uses_pipeline(self):
         """Returns whether the pipeline parallel is used or not."""
-        return self.uses_pipeline
+        return isinstance(self._schedule, (PipelineScheduler, InterleavedPipelineScheduler))

     def train(self):
         self._engine.train()

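Note: the old code was doubly broken: __init__ assigned to self.uses_pipeline even though a read-only property of that name exists on the class (an AttributeError in CPython), and the property body returned self.uses_pipeline, which re-enters the property itself and recurses forever. A minimal self-contained reproduction of the pitfall and the fix (the schedule classes are illustrative stand-ins, not InternLM's):

    class Broken:
        @property
        def uses_pipeline(self):
            return self.uses_pipeline  # re-enters this property -> RecursionError

    class PipelineSchedule:
        pass

    class OtherSchedule:
        pass

    class Fixed:
        def __init__(self, schedule):
            self._schedule = schedule

        @property
        def uses_pipeline(self):
            # derive the flag from stored state rather than from the property itself
            return isinstance(self._schedule, PipelineSchedule)

    assert Fixed(PipelineSchedule()).uses_pipeline
    assert not Fixed(OtherSchedule()).uses_pipeline
    # Broken().uses_pipeline raises RecursionError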

@@ -219,11 +219,6 @@ class StaticBatchSampler:
         assert (
             batch_size - self.start_bsz
         ) % self.bsz_incre == 0, f"{batch_size} - {self.start_bsz} should be multiple of {self.bsz_incre}"
-        assert (
-            self.start_bsz // micro_bsz >= 4
-        ), f"Must have more start samples:`{self.start_bsz}` with micro_bsz:\
-            `{micro_bsz}`, so that the pipeline can run correctly"
         assert batch_size % micro_bsz == 0, f"batch_size({batch_size}) should be multiple of micro_bsz({micro_bsz})"
         assert (
             self.start_bsz % micro_bsz == 0

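Note: only the start_bsz // micro_bsz >= 4 floor is dropped; the surrounding divisibility assertions remain. An illustrative check of what StaticBatchSampler still enforces (the numbers are made up for the example):

    batch_size, start_bsz, micro_bsz, bsz_incre = 32, 4, 2, 4

    assert batch_size % micro_bsz == 0                # 32 % 2 == 0
    assert start_bsz % micro_bsz == 0                 # 4 % 2 == 0
    assert (batch_size - start_bsz) % bsz_incre == 0  # (32 - 4) % 4 == 0

    # start_bsz // micro_bsz == 2 here, which the removed assertion would
    # have rejected (< 4); small demo configs like this are now accepted.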

@@ -80,6 +80,7 @@ def initialize_trainer(
            gradient_handlers.append(handler)

    # initialize scheduler for trainer
+    scheduler = None
    if is_using_pp:
        gpc.config.NUM_MICRO_BATCHES = gpc.config.data.micro_num
        tensor_shape = get_tensor_shape()
@@ -115,7 +116,7 @@ def initialize_trainer(
        )

    # if bf16 is used, this value will be wrongly set to fp32, so it needs to be corrected manually
-    if hasattr(gpc.config.model, "dtype"):
+    if hasattr(gpc.config.model, "dtype") and gpc.config.model.dtype == "torch.bfloat16":
        scheduler.dtype = torch.bfloat16

    trainer = Trainer(engine, scheduler)

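Note: gpc.config.model.dtype is evidently stored as a string (hence the comparison with "torch.bfloat16"), so the old unguarded hasattr check set scheduler.dtype to torch.bfloat16 for every model that declared any dtype, fp16 included. The added scheduler = None in the first hunk simply pre-declares the variable before the branches that construct it. A self-contained sketch of the guarded idea; the mapping and function name are illustrative, not InternLM API:

    import torch

    _DTYPE_FROM_STR = {
        "torch.float16": torch.float16,
        "torch.bfloat16": torch.bfloat16,
        "torch.float32": torch.float32,
    }

    def resolve_scheduler_dtype(config_dtype, default=torch.float32):
        # only override the default when the config names a dtype we know
        return _DTYPE_FROM_STR.get(config_dtype, default)

    assert resolve_scheduler_dtype("torch.bfloat16") is torch.bfloat16
    assert resolve_scheduler_dtype("unknown") is torch.float32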

@@ -13,7 +13,6 @@ import numpy as np
 import torch

 import internlm
-from internlm.core.context import global_context as gpc

 CURRENT_TIME = None
@@ -173,16 +172,6 @@ def conditional_context(context_manager, enable=True):
         yield


-@contextmanager
-def switch_virtual_pipeline_parallel_rank(rank):
-    prev_rank = gpc.virtual_pipeline_parallel_rank
-    try:
-        gpc.set_virtual_pipeline_parallel_rank(rank)
-        yield
-    finally:
-        gpc.set_virtual_pipeline_parallel_rank(prev_rank)
-
-
 class BatchSkipper:
     """
     BatchSkipper is used to determine whether to skip the current batch_idx.
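Note: with the helper now living in the pipeline scheduler, internlm/utils/common.py no longer needs the global context, so the from internlm.core.context import global_context as gpc import is dropped in the first hunk. Since core modules import from internlm.utils (see the scheduler's imports earlier in this commit), utils importing back from core risked a circular import, which is plausibly the demo-breaking error this commit fixes.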