mirror of https://github.com/InternLM/InternLM
fix(core): fix demo running error
parent ad4d13740f
commit 955282dd87
@@ -141,4 +141,5 @@ small_demo/
 core.*
 
 # Run
 llm_ckpts
+events.*
@@ -4,6 +4,7 @@
 # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
 
 import inspect
+from contextlib import contextmanager
 from typing import Callable, List, Tuple, Union
 
 import torch.cuda
@@ -12,11 +13,7 @@ import internlm.core.communication as comm
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
-from internlm.utils.common import (
-    get_current_device,
-    move_to_device,
-    switch_virtual_pipeline_parallel_rank,
-)
+from internlm.utils.common import get_current_device, move_to_device
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 
@@ -76,6 +73,16 @@ def pack_return_tensors(return_tensors):
     return output, label
 
 
+@contextmanager
+def switch_virtual_pipeline_parallel_rank(rank):
+    prev_rank = gpc.virtual_pipeline_parallel_rank
+    try:
+        gpc.set_virtual_pipeline_parallel_rank(rank)
+        yield
+    finally:
+        gpc.set_virtual_pipeline_parallel_rank(prev_rank)
+
+
 class PipelineScheduler(BaseScheduler):
     """A helper schedule class for pipeline parallelism running environment.
     It uses non-interleaved 1F1B strategy. Other properties are similar as
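The hunk above relocates `switch_virtual_pipeline_parallel_rank` into the pipeline scheduler module as a `contextlib` context manager. A minimal self-contained sketch of the same pattern follows; the `_StubParallelContext` class is a stand-in for InternLM's `gpc` singleton used only to make the sketch runnable, not the real API:

from contextlib import contextmanager


class _StubParallelContext:
    """Stand-in for the global parallel context; only models the virtual pipeline rank."""

    def __init__(self):
        self.virtual_pipeline_parallel_rank = 0

    def set_virtual_pipeline_parallel_rank(self, rank):
        self.virtual_pipeline_parallel_rank = rank


gpc = _StubParallelContext()


@contextmanager
def switch_virtual_pipeline_parallel_rank(rank):
    # Temporarily switch the virtual pipeline rank, restoring the previous
    # value even if the wrapped block raises.
    prev_rank = gpc.virtual_pipeline_parallel_rank
    try:
        gpc.set_virtual_pipeline_parallel_rank(rank)
        yield
    finally:
        gpc.set_virtual_pipeline_parallel_rank(prev_rank)


if __name__ == "__main__":
    with switch_virtual_pipeline_parallel_rank(3):
        assert gpc.virtual_pipeline_parallel_rank == 3  # active inside the block
    assert gpc.virtual_pipeline_parallel_rank == 0      # restored afterwards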
@@ -185,8 +192,8 @@ class PipelineScheduler(BaseScheduler):
         # TODO: remove this after testing new zero with pipeline parallelism
         model = engine.model
         dtype = None
-        if isinstance(model, NaiveAMPModel):
-            dtype = torch.half
+        # if isinstance(model, NaiveAMPModel):
+        #     dtype = torch.half
         # TODO: add an operation here so that bf16 is supported as well
         types = set()
         for param in model.parameters():
@@ -8,7 +8,10 @@ from typing import Iterable, Optional
 
 from internlm.core.engine import Engine
 from internlm.core.no_pipeline_scheduler import BaseScheduler, NonPipelineScheduler
-from internlm.core.pipeline_scheduler import PipelineScheduler
+from internlm.core.pipeline_scheduler import (
+    InterleavedPipelineScheduler,
+    PipelineScheduler,
+)
 
 
 class TrainState:
@@ -113,9 +116,8 @@ class Trainer:
         ), f"expected schedule to be of type BaseSchedule, but got {type(schedule)}"
         self._schedule = schedule
 
-        self.uses_pipeline = isinstance(schedule, PipelineScheduler)
         if self.uses_pipeline:
-            self._schedule.pre_processing(self)
+            self._schedule.pre_processing(self._engine)
 
     @property
     def engine(self):
@@ -128,7 +130,7 @@ class Trainer:
     @property
     def uses_pipeline(self):
         """Returns whether the pipeline parallel is used or not."""
-        return self.uses_pipeline
+        return isinstance(self._schedule, (PipelineScheduler, InterleavedPipelineScheduler))
 
     def train(self):
         self._engine.train()
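The two Trainer hunks above address what appears to be the demo-breaking bug: the old `__init__` assigned to `self.uses_pipeline` even though `uses_pipeline` is a read-only property, and the old getter returned `self.uses_pipeline`, which calls itself. A minimal sketch of both pitfalls and the type-based fix, using placeholder scheduler classes rather than the real InternLM ones:

class PipelineScheduler:  # placeholder stand-ins for the real scheduler classes
    pass


class InterleavedPipelineScheduler(PipelineScheduler):
    pass


class BrokenTrainer:
    def __init__(self, schedule):
        # Assigning to a read-only property raises AttributeError immediately.
        self.uses_pipeline = isinstance(schedule, PipelineScheduler)

    @property
    def uses_pipeline(self):
        # Even if construction succeeded, this getter would recurse forever.
        return self.uses_pipeline


class FixedTrainer:
    def __init__(self, schedule):
        self._schedule = schedule

    @property
    def uses_pipeline(self):
        # Derive the flag from the held schedule instead of re-reading the property.
        return isinstance(self._schedule, (PipelineScheduler, InterleavedPipelineScheduler))


if __name__ == "__main__":
    try:
        BrokenTrainer(PipelineScheduler())
    except AttributeError as exc:
        print("old behaviour:", exc)
    print("fixed:", FixedTrainer(PipelineScheduler()).uses_pipeline)  # True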
@@ -219,11 +219,6 @@ class StaticBatchSampler:
         assert (
             batch_size - self.start_bsz
         ) % self.bsz_incre == 0, f"{batch_size} - {self.start_bsz} should be multiple of {self.bsz_incre}"
-        assert (
-            self.start_bsz // micro_bsz >= 4
-        ), f"Must have more start samples:`{self.start_bsz}` with micro_bsz:\
-            `{micro_bsz}`, so that the pipeline can run correctly"
-
         assert batch_size % micro_bsz == 0, f"batch_size({batch_size}) should be multiple of micro_bsz({micro_bsz})"
         assert (
             self.start_bsz % micro_bsz == 0
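The StaticBatchSampler hunk above drops the requirement that the starting batch size cover at least four micro-batches; the divisibility checks remain. A small illustration of those remaining invariants with made-up numbers (the helper name is ours, not InternLM's):

def check_rampup_config(batch_size, start_bsz, bsz_incre, micro_bsz):
    """Mirror of the divisibility checks kept by the diff above (illustrative only)."""
    assert (batch_size - start_bsz) % bsz_incre == 0, (
        f"{batch_size} - {start_bsz} should be multiple of {bsz_incre}"
    )
    assert batch_size % micro_bsz == 0, (
        f"batch_size({batch_size}) should be multiple of micro_bsz({micro_bsz})"
    )
    assert start_bsz % micro_bsz == 0


# Example: ramp from 4 to 16 samples per step in increments of 2, micro batch size 2.
check_rampup_config(batch_size=16, start_bsz=4, bsz_incre=2, micro_bsz=2)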
@@ -80,6 +80,7 @@ def initialize_trainer(
             gradient_handlers.append(handler)
 
     # initialize scheduler for trainer
+    scheduler = None
    if is_using_pp:
        gpc.config.NUM_MICRO_BATCHES = gpc.config.data.micro_num
        tensor_shape = get_tensor_shape()
@@ -115,7 +116,7 @@ def initialize_trainer(
         )
 
     # if bf16 is used, this value will be wrongly set to fp32, so it needs to be corrected manually
-    if hasattr(gpc.config.model, "dtype"):
+    if hasattr(gpc.config.model, "dtype") and gpc.config.model.dtype == "torch.bfloat16":
         scheduler.dtype = torch.bfloat16
 
     trainer = Trainer(engine, scheduler)
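The hunk above narrows the bf16 override so it only fires when the model config actually requests bfloat16, instead of whenever a dtype field exists. A hedged sketch of that selection logic, with a hypothetical helper name and the config value passed in as a plain string:

import torch


def resolve_scheduler_dtype(model_cfg_dtype, default=torch.half):
    """Illustrative helper: pick the scheduler dtype from a config string.

    The config layout (gpc.config.model.dtype holding "torch.bfloat16") is assumed
    from the diff above; only that value triggers the bf16 override, everything
    else keeps the default half-precision dtype.
    """
    if model_cfg_dtype == "torch.bfloat16":
        return torch.bfloat16
    return default


assert resolve_scheduler_dtype("torch.bfloat16") is torch.bfloat16
assert resolve_scheduler_dtype("torch.float16") is torch.half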
@@ -13,7 +13,6 @@ import numpy as np
 import torch
 
 import internlm
-from internlm.core.context import global_context as gpc
 
 CURRENT_TIME = None
 
@@ -173,16 +172,6 @@ def conditional_context(context_manager, enable=True):
         yield
 
 
-@contextmanager
-def switch_virtual_pipeline_parallel_rank(rank):
-    prev_rank = gpc.virtual_pipeline_parallel_rank
-    try:
-        gpc.set_virtual_pipeline_parallel_rank(rank)
-        yield
-    finally:
-        gpc.set_virtual_pipeline_parallel_rank(prev_rank)
-
-
 class BatchSkipper:
     """
     BatchSkipper is used to determine whether to skip the current batch_idx.