mirror of https://github.com/InternLM/InternLM
feat(*): support sequence_parallel (#180)
* support sequence_parallel for no pipeline
* sequence_parallel does not support no-flash-attn
* support sequence parallel for pipeline
* add memory profiler
* Update 13B.py
* add memory profiler
* fix evaluation bug
* remove some unnecessary code
* remove some unnecessary code
* Update parallel_context.py
* modify the config
* remove memory profiler
* modify the config
* support selective dropout
pull/190/head
parent 853becfb6e
commit c219065348
|
@ -118,6 +118,7 @@ model = dict(
|
||||||
layer_norm_epsilon=1e-5,
|
layer_norm_epsilon=1e-5,
|
||||||
use_flash_attn=True,
|
use_flash_attn=True,
|
||||||
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
|
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
|
||||||
|
sequence_parallel=False,
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
zero1 parallel:
|
zero1 parallel:
|
||||||
|
|
|
@ -464,7 +464,6 @@ class ParallelContext(metaclass=SingletonMeta):
|
||||||
initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
|
initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
|
||||||
if self.pipeline_parallel_size > 1:
|
if self.pipeline_parallel_size > 1:
|
||||||
initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
|
initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
|
||||||
|
|
||||||
for initializer in initializers:
|
for initializer in initializers:
|
||||||
parallel_setting = initializer.init_dist_group()
|
parallel_setting = initializer.init_dist_group()
|
||||||
if isinstance(parallel_setting, list):
|
if isinstance(parallel_setting, list):
|
||||||
|
|
|
@ -30,10 +30,17 @@ def get_tensor_shape():
|
||||||
|
|
||||||
if hasattr(gpc.config, "SEQ_LEN") and hasattr(gpc.config.data, "micro_bsz") and hasattr(gpc.config, "HIDDEN_SIZE"):
|
if hasattr(gpc.config, "SEQ_LEN") and hasattr(gpc.config.data, "micro_bsz") and hasattr(gpc.config, "HIDDEN_SIZE"):
|
||||||
if gpc.config.model.use_flash_attn:
|
if gpc.config.model.use_flash_attn:
|
||||||
tensor_shape = (
|
if gpc.config.model.sequence_parallel:
|
||||||
gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"],
|
sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
|
||||||
gpc.config.HIDDEN_SIZE,
|
tensor_shape = (
|
||||||
)
|
gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"] // sequence_world_size,
|
||||||
|
gpc.config.HIDDEN_SIZE,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
tensor_shape = (
|
||||||
|
gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"],
|
||||||
|
gpc.config.HIDDEN_SIZE,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
tensor_shape = (
|
tensor_shape = (
|
||||||
gpc.config.data["micro_bsz"],
|
gpc.config.data["micro_bsz"],
|
||||||
|
@ -132,6 +139,9 @@ class PipelineScheduler(BaseScheduler):
|
||||||
and gpc.is_initialized(ParallelMode.TENSOR)
|
and gpc.is_initialized(ParallelMode.TENSOR)
|
||||||
and gpc.get_world_size(ParallelMode.TENSOR) > 1
|
and gpc.get_world_size(ParallelMode.TENSOR) > 1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if gpc.config.model.sequence_parallel:
|
||||||
|
self.scatter_gather_tensors = False
|
||||||
|
|
||||||
# cache for the batch data
|
# cache for the batch data
|
||||||
self.batch_data = None
|
self.batch_data = None
|
||||||
|
@ -254,7 +264,6 @@ class PipelineScheduler(BaseScheduler):
|
||||||
|
|
||||||
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||||
self._call_hooks("post_helper_func", output_obj, label)
|
self._call_hooks("post_helper_func", output_obj, label)
|
||||||
|
|
||||||
if return_output_label:
|
if return_output_label:
|
||||||
return_tensors.append((output_obj, label))
|
return_tensors.append((output_obj, label))
|
||||||
if accum_loss is not None:
|
if accum_loss is not None:
|
||||||
|
|
|
@ -38,6 +38,7 @@ def get_default_parser():
|
||||||
parser.add_argument("--local_rank", type=int, help="local rank on the node")
|
parser.add_argument("--local_rank", type=int, help="local rank on the node")
|
||||||
parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication")
|
parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication")
|
||||||
parser.add_argument("--seed", type=int, default=1024)
|
parser.add_argument("--seed", type=int, default=1024)
|
||||||
|
parser.add_argument("--profiling", default=True, action="store_true", help="enable/diable profiling.")
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
@ -198,6 +199,10 @@ def args_sanity_check():
|
||||||
# process the model config
|
# process the model config
|
||||||
if "use_flash_attn" not in gpc.config.model:
|
if "use_flash_attn" not in gpc.config.model:
|
||||||
gpc.config.model._add_item("use_flash_attn", True)
|
gpc.config.model._add_item("use_flash_attn", True)
|
||||||
|
if "sequence_parallel" not in gpc.config.model:
|
||||||
|
gpc.config.model._add_item("sequence_parallel", False)
|
||||||
|
else:
|
||||||
|
assert not (gpc.config.model.sequence_parallel is True and gpc.config.model.use_flash_attn is False), "sequence parallel does not support use_flash_attn=False"
|
||||||
|
|
||||||
|
|
||||||
def launch(
|
def launch(
|
||||||
|
|
|
@ -13,7 +13,7 @@ from torch import Tensor, nn
|
||||||
from internlm.core.context import ParallelMode
|
from internlm.core.context import ParallelMode
|
||||||
from internlm.core.context import global_context as gpc
|
from internlm.core.context import global_context as gpc
|
||||||
|
|
||||||
from .utils import gather_forward_split_backward
|
from .utils import gather_forward_split_backward, split_forward_gather_backward
|
||||||
|
|
||||||
|
|
||||||
class Embedding1D(nn.Module):
|
class Embedding1D(nn.Module):
|
||||||
|
@ -55,7 +55,10 @@ class Embedding1D(nn.Module):
|
||||||
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
||||||
|
|
||||||
output = gather_forward_split_backward(output_parallel, ParallelMode.TENSOR, dim=-1)
|
output = gather_forward_split_backward(output_parallel, ParallelMode.TENSOR, dim=-1)
|
||||||
|
|
||||||
|
if gpc.config.model.sequence_parallel:
|
||||||
|
output = split_forward_gather_backward(output, ParallelMode.TENSOR, dim=1)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -38,7 +38,6 @@ class ScaleColumnParallelLinear(nn.Linear):
|
||||||
out_features: int,
|
out_features: int,
|
||||||
process_group: Optional[torch.distributed.ProcessGroup],
|
process_group: Optional[torch.distributed.ProcessGroup],
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
sequence_parallel: bool = True,
|
|
||||||
device: Optional[torch.device] = None,
|
device: Optional[torch.device] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
weight_scale: int = 1,
|
weight_scale: int = 1,
|
||||||
|
@ -48,7 +47,6 @@ class ScaleColumnParallelLinear(nn.Linear):
|
||||||
raise ValueError(f"out_features ({out_features}) must be divisible by " f"world_size ({world_size})")
|
raise ValueError(f"out_features ({out_features}) must be divisible by " f"world_size ({world_size})")
|
||||||
super().__init__(in_features, out_features // world_size, bias=bias, device=device, dtype=dtype)
|
super().__init__(in_features, out_features // world_size, bias=bias, device=device, dtype=dtype)
|
||||||
self.process_group = process_group
|
self.process_group = process_group
|
||||||
self.sequence_parallel = sequence_parallel
|
|
||||||
self.weight_scale = weight_scale
|
self.weight_scale = weight_scale
|
||||||
|
|
||||||
def forward(self, input): # pylint: disable=W0622
|
def forward(self, input): # pylint: disable=W0622
|
||||||
|
@ -60,7 +58,7 @@ class ScaleColumnParallelLinear(nn.Linear):
|
||||||
else:
|
else:
|
||||||
weight = self.weight
|
weight = self.weight
|
||||||
return fused_dense_func_torch(
|
return fused_dense_func_torch(
|
||||||
input, weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel
|
input, weight, self.bias, process_group=self.process_group, sequence_parallel=gpc.config.model.sequence_parallel
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -87,12 +85,11 @@ class RewardModelLinear(ScaleColumnParallelLinear):
|
||||||
out_features: int,
|
out_features: int,
|
||||||
process_group: Optional[torch.distributed.ProcessGroup],
|
process_group: Optional[torch.distributed.ProcessGroup],
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
sequence_parallel: bool = True,
|
|
||||||
device: Optional[torch.device] = None,
|
device: Optional[torch.device] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
weight_scale: int = 1,
|
weight_scale: int = 1,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(in_features, out_features, process_group, bias, sequence_parallel, device, dtype, weight_scale)
|
super().__init__(in_features, out_features, process_group, bias, device, dtype, weight_scale)
|
||||||
torch.distributed.broadcast(self.weight, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], process_group)
|
torch.distributed.broadcast(self.weight, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], process_group)
|
||||||
if bias:
|
if bias:
|
||||||
torch.distributed.broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], process_group)
|
torch.distributed.broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], process_group)
|
||||||
|
@ -106,7 +103,7 @@ class RewardModelLinear(ScaleColumnParallelLinear):
|
||||||
else:
|
else:
|
||||||
weight = self.weight
|
weight = self.weight
|
||||||
return fused_dense_func_torch(
|
return fused_dense_func_torch(
|
||||||
input, weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel
|
input, weight, self.bias, process_group=self.process_group, sequence_parallel=gpc.config.model.sequence_parallel
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -168,19 +165,19 @@ class FeedForward(nn.Module):
|
||||||
hidden_features,
|
hidden_features,
|
||||||
process_group,
|
process_group,
|
||||||
bias,
|
bias,
|
||||||
sequence_parallel=False,
|
sequence_parallel=gpc.config.model.sequence_parallel,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
self.w2 = ColumnParallelLinearTorch(
|
self.w2 = ColumnParallelLinearTorch(
|
||||||
in_features, hidden_features, process_group, bias, sequence_parallel=False, device=device, dtype=dtype
|
in_features, hidden_features, process_group, bias, sequence_parallel=gpc.config.model.sequence_parallel, device=device, dtype=dtype
|
||||||
)
|
)
|
||||||
self.w3 = RowParallelLinearTorch(
|
self.w3 = RowParallelLinearTorch(
|
||||||
hidden_features,
|
hidden_features,
|
||||||
out_features,
|
out_features,
|
||||||
process_group,
|
process_group,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
sequence_parallel=False,
|
sequence_parallel=gpc.config.model.sequence_parallel,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
|
|
@ -89,7 +89,6 @@ class PackedFlashBaseLayer1D(nn.Module):
|
||||||
rotary_emb_dim=head_dim,
|
rotary_emb_dim=head_dim,
|
||||||
rotary_emb_scale_base=0,
|
rotary_emb_scale_base=0,
|
||||||
use_flash_attn=use_flash_attn,
|
use_flash_attn=use_flash_attn,
|
||||||
sequence_parallel=False,
|
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
@ -121,7 +120,7 @@ class PackedFlashBaseLayer1D(nn.Module):
|
||||||
process_group=gpc.get_group(ParallelMode.TENSOR),
|
process_group=gpc.get_group(ParallelMode.TENSOR),
|
||||||
bias1=False,
|
bias1=False,
|
||||||
bias2=False,
|
bias2=False,
|
||||||
sequence_parallel=False,
|
sequence_parallel=gpc.config.model.sequence_parallel,
|
||||||
checkpoint_lvl=0,
|
checkpoint_lvl=0,
|
||||||
heuristic="auto",
|
heuristic="auto",
|
||||||
device=device,
|
device=device,
|
||||||
|
@ -300,7 +299,7 @@ class PackedFlashInternLm1D(nn.Module):
|
||||||
max_position_embeddings=-1,
|
max_position_embeddings=-1,
|
||||||
process_group=gpc.get_group(ParallelMode.TENSOR),
|
process_group=gpc.get_group(ParallelMode.TENSOR),
|
||||||
padding_idx=None,
|
padding_idx=None,
|
||||||
sequence_parallel=False,
|
sequence_parallel=gpc.config.model.sequence_parallel,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
@ -342,7 +341,6 @@ class PackedFlashInternLm1D(nn.Module):
|
||||||
out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
|
out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
|
||||||
process_group=gpc.get_group(ParallelMode.TENSOR),
|
process_group=gpc.get_group(ParallelMode.TENSOR),
|
||||||
bias=False,
|
bias=False,
|
||||||
sequence_parallel=False,
|
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
weight_scale=embed_grad_scale,
|
weight_scale=embed_grad_scale,
|
||||||
|
@ -463,6 +461,7 @@ def build_model_with_cfg(
|
||||||
use_scaled_init: bool = True,
|
use_scaled_init: bool = True,
|
||||||
use_swiglu: bool = True,
|
use_swiglu: bool = True,
|
||||||
use_flash_attn: bool = True,
|
use_flash_attn: bool = True,
|
||||||
|
sequence_parallel: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Builde model with config
|
Builde model with config
|
||||||
|
|
|
@ -59,7 +59,6 @@ class MHA(nn.Module):
|
||||||
rotary_emb_dim: int = 0,
|
rotary_emb_dim: int = 0,
|
||||||
rotary_emb_scale_base: int = 0,
|
rotary_emb_scale_base: int = 0,
|
||||||
use_flash_attn: bool = True,
|
use_flash_attn: bool = True,
|
||||||
sequence_parallel: bool = True,
|
|
||||||
device: Optional[torch.device] = None,
|
device: Optional[torch.device] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -83,7 +82,7 @@ class MHA(nn.Module):
|
||||||
3 * embed_dim,
|
3 * embed_dim,
|
||||||
process_group,
|
process_group,
|
||||||
bias=True,
|
bias=True,
|
||||||
sequence_parallel=sequence_parallel,
|
sequence_parallel=gpc.config.model.sequence_parallel,
|
||||||
**factory_kwargs,
|
**factory_kwargs,
|
||||||
) # according to https://spaces.ac.cn/archives/9577
|
) # according to https://spaces.ac.cn/archives/9577
|
||||||
|
|
||||||
|
@ -96,7 +95,7 @@ class MHA(nn.Module):
|
||||||
|
|
||||||
# output projection always have the bias (for now)
|
# output projection always have the bias (for now)
|
||||||
self.out_proj = RowParallelLinearTorch(
|
self.out_proj = RowParallelLinearTorch(
|
||||||
embed_dim, embed_dim, process_group, sequence_parallel=sequence_parallel, **factory_kwargs
|
embed_dim, embed_dim, process_group, sequence_parallel=gpc.config.model.sequence_parallel, **factory_kwargs
|
||||||
)
|
)
|
||||||
# need to assign tp attribute so that internlm know it is tensor parallel module
|
# need to assign tp attribute so that internlm know it is tensor parallel module
|
||||||
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
|
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
|
||||||
|
|
|
@ -157,6 +157,35 @@ def fused_dense_func_torch(
|
||||||
else:
|
else:
|
||||||
return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel)
|
return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel)
|
||||||
|
|
||||||
|
class _SplitForwardGatherBackward(torch.autograd.Function):
|
||||||
|
"""
|
||||||
|
Split the input and keep only the corresponding chuck to the rank.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_: input matrix.
|
||||||
|
parallel_mode: parallel mode.
|
||||||
|
dim: dimension
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def symbolic(graph, input_):
|
||||||
|
return _split(input_)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, input_, parallel_mode, dim):
|
||||||
|
ctx.mode = parallel_mode
|
||||||
|
ctx.dim = dim
|
||||||
|
return _split(input_, parallel_mode, dim)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
return _gather(grad_output, ctx.mode, ctx.dim), None, None
|
||||||
|
|
||||||
|
|
||||||
|
def split_forward_gather_backward(input_, parallel_mode, dim):
|
||||||
|
return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim)
|
||||||
|
|
||||||
|
|
||||||
def try_import_RMSNorm():
|
def try_import_RMSNorm():
|
||||||
"""
|
"""
|
||||||
Try import MixFusedRMSNorm from apex, if failed, return our RMSNorm
|
Try import MixFusedRMSNorm from apex, if failed, return our RMSNorm
|
||||||
|
|
|
@ -90,9 +90,15 @@ def evaluate_on_val_dls(
|
||||||
total_val_bsz = len(batch[1])
|
total_val_bsz = len(batch[1])
|
||||||
assert total_val_bsz % data_cfg.micro_bsz == 0
|
assert total_val_bsz % data_cfg.micro_bsz == 0
|
||||||
num_microbatches = total_val_bsz // data_cfg.micro_bsz
|
num_microbatches = total_val_bsz // data_cfg.micro_bsz
|
||||||
tensor_shape = torch.Size(
|
if gpc.config.model.sequence_parallel:
|
||||||
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
|
sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
|
||||||
)
|
tensor_shape = torch.Size(
|
||||||
|
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1] // sequence_world_size, gpc.config.HIDDEN_SIZE]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
tensor_shape = torch.Size(
|
||||||
|
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
|
||||||
|
)
|
||||||
|
|
||||||
with switch_evaluation_pipeline_scheduler(
|
with switch_evaluation_pipeline_scheduler(
|
||||||
trainer=trainer,
|
trainer=trainer,
|
||||||
|
@ -108,7 +114,7 @@ def evaluate_on_val_dls(
|
||||||
assert total_val_bsz % data_cfg.micro_bsz == 0
|
assert total_val_bsz % data_cfg.micro_bsz == 0
|
||||||
grad_accum_size = total_val_bsz // data_cfg.micro_bsz
|
grad_accum_size = total_val_bsz // data_cfg.micro_bsz
|
||||||
grad_accum_batch_size = data_cfg.micro_bsz
|
grad_accum_batch_size = data_cfg.micro_bsz
|
||||||
|
# import pdb; pdb.set_trace()
|
||||||
with switch_evaluation_no_pipeline_scheduler(
|
with switch_evaluation_no_pipeline_scheduler(
|
||||||
trainer=trainer,
|
trainer=trainer,
|
||||||
grad_accum_size=grad_accum_size,
|
grad_accum_size=grad_accum_size,
|
||||||
|
@ -155,3 +161,13 @@ def evaluate_on_val_dls(
|
||||||
trainer.train()
|
trainer.train()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def switch_sequence_parallel_mode():
|
||||||
|
prev_mode = gpc.config.model.sequence_parallel
|
||||||
|
try:
|
||||||
|
gpc.config.model.sequence_parallel = False
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
gpc.config.model.sequence_parallel = prev_mode
|
19
train.py
19
train.py
|
@ -41,7 +41,7 @@ from internlm.utils.common import (
|
||||||
launch_time,
|
launch_time,
|
||||||
parse_args,
|
parse_args,
|
||||||
)
|
)
|
||||||
from internlm.utils.evaluation import evaluate_on_val_dls
|
from internlm.utils.evaluation import evaluate_on_val_dls, switch_sequence_parallel_mode
|
||||||
from internlm.utils.logger import get_logger, initialize_uniscale_logger
|
from internlm.utils.logger import get_logger, initialize_uniscale_logger
|
||||||
from internlm.utils.megatron_timers import megatron_timer as timer
|
from internlm.utils.megatron_timers import megatron_timer as timer
|
||||||
from internlm.utils.model_checkpoint import (
|
from internlm.utils.model_checkpoint import (
|
||||||
|
@ -618,14 +618,15 @@ def main(args):
|
||||||
|
|
||||||
# evaluate on validation data loaders
|
# evaluate on validation data loaders
|
||||||
if valid_every > 0 and train_state.step_count % valid_every == 0:
|
if valid_every > 0 and train_state.step_count % valid_every == 0:
|
||||||
evaluate_on_val_dls(
|
with switch_sequence_parallel_mode():
|
||||||
trainer=trainer,
|
evaluate_on_val_dls(
|
||||||
val_dls=val_dls,
|
trainer=trainer,
|
||||||
writer=writer,
|
val_dls=val_dls,
|
||||||
logger=logger,
|
writer=writer,
|
||||||
step_count=train_state.step_count,
|
logger=logger,
|
||||||
update_panel=uniscale_logger is not None,
|
step_count=train_state.step_count,
|
||||||
)
|
update_panel=uniscale_logger is not None,
|
||||||
|
)
|
||||||
|
|
||||||
# checkpoint the training states in specific steps, which is determined by the args "checkpoint_every"
|
# checkpoint the training states in specific steps, which is determined by the args "checkpoint_every"
|
||||||
# save batch sampler that tracks the true consumed samples
|
# save batch sampler that tracks the true consumed samples
|
||||||
|
|
Loading…
Reference in New Issue