mirror of https://github.com/InternLM/InternLM
refactor(scheduler): rewrite pipeline scheduler (#138)
* refactor(scheduler): rewrite pipeline scheduler * fix(*): fix pipeline scheduler bugs * fix(*): fix merge bug * feat(*): update codes with todo tag * feat(*): add comments * feat(internlm/core/scheduler): update recv_prev/next logic * feat(utils/evaluation.py): update sche metric hook for valid --------- Co-authored-by: huangting.p <huangting@sensetime.com>pull/155/head
parent
d67be17f96
commit
0268d8eda1
|
@ -117,6 +117,7 @@ model = dict(
|
|||
norm_type="rmsnorm",
|
||||
layer_norm_epsilon=1e-5,
|
||||
use_flash_attn=True,
|
||||
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
|
||||
)
|
||||
"""
|
||||
zero1 parallel:
|
||||
|
@ -125,12 +126,14 @@ zero1 parallel:
|
|||
2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
|
||||
3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size.
|
||||
For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
|
||||
pipeline parallel: pipeline parallel size.
|
||||
pipeline parallel (dict):
|
||||
1. size: int, the size of pipeline parallel.
|
||||
2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler.
|
||||
tensor parallel: tensor parallel size, usually the number of GPUs per node.
|
||||
"""
|
||||
parallel = dict(
|
||||
zero1=8,
|
||||
pipeline=2,
|
||||
pipeline=dict(size=1, interleaved_overlap=True),
|
||||
)
|
||||
|
||||
cudnn_deterministic = False
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
from .p2p import (
|
||||
AsynCommunicator,
|
||||
recv_backward,
|
||||
recv_forward,
|
||||
send_backward,
|
||||
send_backward_and_recv_next_backward_async,
|
||||
send_backward_recv_backward,
|
||||
send_backward_recv_forward,
|
||||
send_forward,
|
||||
send_forward_and_recv_next_forward_async,
|
||||
send_forward_backward_recv_forward_backward,
|
||||
send_forward_recv_backward,
|
||||
send_forward_recv_forward,
|
||||
|
@ -23,4 +26,7 @@ __all__ = [
|
|||
"recv_forward",
|
||||
"send_obj_meta",
|
||||
"recv_obj_meta",
|
||||
"send_backward_and_recv_next_backward_async",
|
||||
"send_forward_and_recv_next_forward_async",
|
||||
"AsynCommunicator",
|
||||
]
|
||||
|
|
|
@ -207,16 +207,13 @@ def recv_forward(
|
|||
Returns:
|
||||
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor or input tensor list.
|
||||
"""
|
||||
if gpc.is_pipeline_first_stage():
|
||||
input_tensor = None
|
||||
else:
|
||||
input_tensor, _ = _communicate(
|
||||
recv_prev=True,
|
||||
recv_prev_shape=input_tensor_shape,
|
||||
prev_rank=prev_rank,
|
||||
dtype=dtype,
|
||||
scatter_gather_tensors=scatter_gather_tensors,
|
||||
)
|
||||
input_tensor, _ = _communicate(
|
||||
recv_prev=True,
|
||||
recv_prev_shape=input_tensor_shape,
|
||||
prev_rank=prev_rank,
|
||||
dtype=dtype,
|
||||
scatter_gather_tensors=scatter_gather_tensors,
|
||||
)
|
||||
return input_tensor
|
||||
|
||||
|
||||
|
@ -233,16 +230,13 @@ def recv_backward(
|
|||
Returns:
|
||||
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor or gradident tensor list.
|
||||
"""
|
||||
if gpc.is_pipeline_last_stage():
|
||||
output_tensor_grad = None
|
||||
else:
|
||||
_, output_tensor_grad = _communicate(
|
||||
recv_next=True,
|
||||
recv_next_shape=output_grad_shape,
|
||||
next_rank=next_rank,
|
||||
dtype=dtype,
|
||||
scatter_gather_tensors=scatter_gather_tensors,
|
||||
)
|
||||
_, output_tensor_grad = _communicate(
|
||||
recv_next=True,
|
||||
recv_next_shape=output_grad_shape,
|
||||
next_rank=next_rank,
|
||||
dtype=dtype,
|
||||
scatter_gather_tensors=scatter_gather_tensors,
|
||||
)
|
||||
return output_tensor_grad
|
||||
|
||||
|
||||
|
@ -253,8 +247,7 @@ def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False) ->
|
|||
output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
|
||||
next_rank (int, optional): The rank of the recipient of the tensor.
|
||||
"""
|
||||
if not gpc.is_pipeline_last_stage():
|
||||
_communicate(object_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors)
|
||||
_communicate(object_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors)
|
||||
|
||||
|
||||
def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False) -> None:
|
||||
|
@ -264,14 +257,12 @@ def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=Fals
|
|||
input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent
|
||||
prev_rank (int, optional): The rank of the recipient of the tensor
|
||||
"""
|
||||
if not gpc.is_pipeline_first_stage():
|
||||
_communicate(
|
||||
object_send_prev=input_tensor_grad, prev_rank=prev_rank, scatter_gather_tensors=scatter_gather_tensors
|
||||
)
|
||||
|
||||
_communicate(object_send_prev=input_tensor_grad, prev_rank=prev_rank, scatter_gather_tensors=scatter_gather_tensors)
|
||||
|
||||
|
||||
def send_forward_recv_backward(
|
||||
output_tensor, output_grad_shape, recv_next=True, next_rank=None, dtype=torch.float, scatter_gather_tensors=False
|
||||
output_tensor, output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
"""Batched communication operation. Sends the input tensor to the
|
||||
next stage in pipeline, while receives the gradient tensor from the
|
||||
|
@ -285,24 +276,21 @@ def send_forward_recv_backward(
|
|||
Returns:
|
||||
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor.
|
||||
"""
|
||||
if gpc.is_pipeline_last_stage():
|
||||
output_tensor_grad = None
|
||||
else:
|
||||
_, output_tensor_grad = _communicate(
|
||||
object_send_next=output_tensor,
|
||||
recv_next=recv_next,
|
||||
recv_next_shape=output_grad_shape,
|
||||
next_rank=next_rank,
|
||||
dtype=dtype,
|
||||
scatter_gather_tensors=scatter_gather_tensors,
|
||||
)
|
||||
_, output_tensor_grad = _communicate(
|
||||
object_send_next=output_tensor,
|
||||
recv_next=output_grad_shape is not None,
|
||||
recv_next_shape=output_grad_shape,
|
||||
next_rank=next_rank,
|
||||
dtype=dtype,
|
||||
scatter_gather_tensors=scatter_gather_tensors,
|
||||
)
|
||||
|
||||
return output_tensor_grad
|
||||
|
||||
|
||||
def send_backward_recv_forward(
|
||||
input_tensor_grad,
|
||||
input_tensor_shape,
|
||||
recv_prev=True,
|
||||
prev_rank=None,
|
||||
dtype=torch.float,
|
||||
scatter_gather_tensors=False,
|
||||
|
@ -319,24 +307,21 @@ def send_backward_recv_forward(
|
|||
Returns:
|
||||
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor.
|
||||
"""
|
||||
if gpc.is_pipeline_first_stage():
|
||||
input_tensor = None
|
||||
else:
|
||||
input_tensor, _ = _communicate(
|
||||
object_send_prev=input_tensor_grad,
|
||||
recv_prev=recv_prev,
|
||||
recv_prev_shape=input_tensor_shape,
|
||||
prev_rank=prev_rank,
|
||||
dtype=dtype,
|
||||
scatter_gather_tensors=scatter_gather_tensors,
|
||||
)
|
||||
input_tensor, _ = _communicate(
|
||||
object_send_prev=input_tensor_grad,
|
||||
recv_prev=input_tensor_shape is not None,
|
||||
recv_prev_shape=input_tensor_shape,
|
||||
prev_rank=prev_rank,
|
||||
dtype=dtype,
|
||||
scatter_gather_tensors=scatter_gather_tensors,
|
||||
)
|
||||
|
||||
return input_tensor
|
||||
|
||||
|
||||
def send_forward_recv_forward(
|
||||
output_tensor,
|
||||
input_tensor_shape,
|
||||
recv_prev=True,
|
||||
prev_rank=None,
|
||||
next_rank=None,
|
||||
dtype=torch.float,
|
||||
|
@ -356,7 +341,7 @@ def send_forward_recv_forward(
|
|||
"""
|
||||
input_tensor, _ = _communicate(
|
||||
object_send_next=output_tensor,
|
||||
recv_prev=recv_prev,
|
||||
recv_prev=input_tensor_shape is not None,
|
||||
recv_prev_shape=input_tensor_shape,
|
||||
prev_rank=prev_rank,
|
||||
next_rank=next_rank,
|
||||
|
@ -369,7 +354,6 @@ def send_forward_recv_forward(
|
|||
def send_backward_recv_backward(
|
||||
input_tensor_grad,
|
||||
output_grad_shape,
|
||||
recv_next=True,
|
||||
prev_rank=None,
|
||||
next_rank=None,
|
||||
dtype=torch.float,
|
||||
|
@ -389,7 +373,7 @@ def send_backward_recv_backward(
|
|||
"""
|
||||
_, output_tensor_grad = _communicate(
|
||||
object_send_prev=input_tensor_grad,
|
||||
recv_next=recv_next,
|
||||
recv_next=output_grad_shape is not None,
|
||||
recv_next_shape=output_grad_shape,
|
||||
prev_rank=prev_rank,
|
||||
next_rank=next_rank,
|
||||
|
@ -404,8 +388,6 @@ def send_forward_backward_recv_forward_backward(
|
|||
input_tensor_grad,
|
||||
input_tensor_shape,
|
||||
output_grad_shape,
|
||||
recv_prev=True,
|
||||
recv_next=True,
|
||||
prev_rank=None,
|
||||
next_rank=None,
|
||||
dtype=torch.float,
|
||||
|
@ -430,8 +412,8 @@ def send_forward_backward_recv_forward_backward(
|
|||
input_tensor, output_tensor_grad = _communicate(
|
||||
object_send_next=output_tensor,
|
||||
object_send_prev=input_tensor_grad,
|
||||
recv_prev=recv_prev,
|
||||
recv_next=recv_next,
|
||||
recv_prev=input_tensor_shape is not None,
|
||||
recv_next=output_grad_shape is not None,
|
||||
recv_prev_shape=input_tensor_shape,
|
||||
recv_next_shape=output_grad_shape,
|
||||
prev_rank=prev_rank,
|
||||
|
@ -440,3 +422,159 @@ def send_forward_backward_recv_forward_backward(
|
|||
scatter_gather_tensors=scatter_gather_tensors,
|
||||
)
|
||||
return input_tensor, output_tensor_grad
|
||||
|
||||
|
||||
def send_forward_and_recv_next_forward_async(
|
||||
output_tensor,
|
||||
recv_prev_shape: Union[torch.Size, List[torch.Size]] = None,
|
||||
dtype: torch.dtype = None,
|
||||
scatter_gather_tensors=False,
|
||||
):
|
||||
"""send forward output to next rank and recv forward input from prev rank"""
|
||||
|
||||
reqs = []
|
||||
tensor_recv_prev = None
|
||||
|
||||
# prepare send opreations
|
||||
if output_tensor is not None:
|
||||
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
|
||||
|
||||
output_tensor = process_object_to_send(output_tensor, scatter_gather_tensors)
|
||||
|
||||
if isinstance(output_tensor, torch.Tensor):
|
||||
reqs.append(dist.P2POp(dist.isend, output_tensor, next_rank))
|
||||
else:
|
||||
for tensor_to_comm in output_tensor:
|
||||
reqs.append(dist.P2POp(dist.isend, tensor_to_comm, next_rank))
|
||||
|
||||
# prepare receive opreations
|
||||
if recv_prev_shape is not None:
|
||||
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
|
||||
# create receive buffer
|
||||
tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes(
|
||||
recv_prev_shape, dtype, scatter_gather_tensors
|
||||
)
|
||||
# generate async receive opterations
|
||||
if isinstance(tensor_recv_prev, torch.Tensor):
|
||||
reqs.append(dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank))
|
||||
else:
|
||||
for tensor_to_comm in tensor_recv_prev:
|
||||
reqs.append(dist.P2POp(dist.irecv, tensor_to_comm, prev_rank))
|
||||
|
||||
if len(reqs) > 0:
|
||||
reqs = dist.batch_isend_irecv(reqs)
|
||||
|
||||
# return and do other things
|
||||
yield
|
||||
|
||||
# check communication completed
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
# To protect against race condition when using batch_isend_irecv()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Process received data
|
||||
if recv_prev_shape is not None and recv_prev_split:
|
||||
if isinstance(tensor_recv_prev, torch.Tensor):
|
||||
tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
|
||||
else:
|
||||
for index in range(len(tensor_recv_prev)):
|
||||
tensor_recv_prev[index] = (
|
||||
gather_split_1d_tensor(tensor_recv_prev[index]).view(recv_prev_shape[index]).requires_grad_()
|
||||
)
|
||||
|
||||
yield tensor_recv_prev
|
||||
|
||||
|
||||
def send_backward_and_recv_next_backward_async(
|
||||
input_tensor,
|
||||
recv_next_shape: Union[torch.Size, List[torch.Size]] = None,
|
||||
dtype: torch.dtype = None,
|
||||
scatter_gather_tensors=False,
|
||||
):
|
||||
reqs = []
|
||||
tensor_recv_next = None
|
||||
|
||||
# prepare send opreations
|
||||
if input_tensor is not None:
|
||||
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
|
||||
|
||||
input_tensor = process_object_to_send(input_tensor, scatter_gather_tensors)
|
||||
|
||||
if isinstance(input_tensor, torch.Tensor):
|
||||
reqs.append(dist.P2POp(dist.isend, input_tensor, prev_rank))
|
||||
else:
|
||||
for tensor_to_comm in input_tensor:
|
||||
reqs.append(dist.P2POp(dist.isend, tensor_to_comm, prev_rank))
|
||||
|
||||
# prepare receive opreations
|
||||
if recv_next_shape is not None:
|
||||
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
|
||||
# create receive buffer
|
||||
tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes(
|
||||
recv_next_shape, dtype, scatter_gather_tensors
|
||||
)
|
||||
# generate async receive opreations
|
||||
if isinstance(tensor_recv_next, torch.Tensor):
|
||||
reqs.append(dist.P2POp(dist.irecv, tensor_recv_next, next_rank))
|
||||
else:
|
||||
for tensor_to_comm in tensor_recv_next:
|
||||
reqs.append(dist.P2POp(dist.irecv, tensor_to_comm, next_rank))
|
||||
|
||||
if len(reqs) > 0:
|
||||
reqs = dist.batch_isend_irecv(reqs)
|
||||
|
||||
# return and do other things
|
||||
yield
|
||||
|
||||
# check communication completed
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
# To protect against race condition when using batch_isend_irecv()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Process received data
|
||||
if recv_next_shape is not None and recv_next_split:
|
||||
if isinstance(tensor_recv_next, torch.Tensor):
|
||||
tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
|
||||
else:
|
||||
for index in range(len(tensor_recv_next)):
|
||||
tensor_recv_next[index] = (
|
||||
gather_split_1d_tensor(tensor_recv_next[index]).view(recv_next_shape[index]).requires_grad_()
|
||||
)
|
||||
|
||||
yield tensor_recv_next
|
||||
|
||||
|
||||
class AsynCommunicator:
|
||||
def __init__(
|
||||
self,
|
||||
tensor_to_send: Union[torch.Tensor, List[torch.Tensor]],
|
||||
recv_shape: Union[torch.Size, List[torch.Size]],
|
||||
dtype: torch.dtype = None,
|
||||
scatter_gather_tensors=False,
|
||||
forward: bool = True,
|
||||
) -> None:
|
||||
self._need_receive = recv_shape is not None
|
||||
|
||||
if forward:
|
||||
self._coroutine = send_forward_and_recv_next_forward_async(
|
||||
tensor_to_send, recv_shape, dtype, scatter_gather_tensors
|
||||
)
|
||||
else:
|
||||
self._coroutine = send_backward_and_recv_next_backward_async(
|
||||
tensor_to_send, recv_shape, dtype, scatter_gather_tensors
|
||||
)
|
||||
|
||||
@property
|
||||
def need_receive(self) -> bool:
|
||||
return self._need_receive
|
||||
|
||||
def start(self) -> None:
|
||||
next(self._coroutine)
|
||||
|
||||
def wait_and_receive(self) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
received = next(self._coroutine)
|
||||
self._coroutine.close()
|
||||
|
||||
return received
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/communication
|
||||
|
||||
from functools import wraps
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
@ -19,7 +20,7 @@ def send_meta_helper(obj, next_rank, tensor_kwargs):
|
|||
dist.send(send_shape, next_rank)
|
||||
|
||||
|
||||
def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool:
|
||||
def send_obj_meta(obj, next_rank=None):
|
||||
"""Sends obj meta information before sending a specific obj.
|
||||
Since the recipient must know the shape of the obj in p2p communications,
|
||||
meta information of the obj should be sent before communications. This function
|
||||
|
@ -33,22 +34,19 @@ def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool:
|
|||
Returns:
|
||||
bool: False
|
||||
"""
|
||||
if need_meta:
|
||||
if next_rank is None:
|
||||
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
|
||||
if next_rank is None:
|
||||
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
|
||||
|
||||
tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
|
||||
if isinstance(obj, torch.Tensor):
|
||||
send_obj_nums = torch.tensor(1, **tensor_kwargs)
|
||||
dist.send(send_obj_nums, next_rank)
|
||||
send_meta_helper(obj, next_rank, tensor_kwargs)
|
||||
else:
|
||||
send_obj_nums = torch.tensor(len(obj), **tensor_kwargs)
|
||||
dist.send(send_obj_nums, next_rank)
|
||||
for tensor_to_send in obj:
|
||||
send_meta_helper(tensor_to_send, next_rank, tensor_kwargs)
|
||||
|
||||
return False
|
||||
tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
|
||||
if isinstance(obj, torch.Tensor):
|
||||
send_obj_nums = torch.tensor(1, **tensor_kwargs)
|
||||
dist.send(send_obj_nums, next_rank)
|
||||
send_meta_helper(obj, next_rank, tensor_kwargs)
|
||||
else:
|
||||
send_obj_nums = torch.tensor(len(obj), **tensor_kwargs)
|
||||
dist.send(send_obj_nums, next_rank)
|
||||
for tensor_to_send in obj:
|
||||
send_meta_helper(tensor_to_send, next_rank, tensor_kwargs)
|
||||
|
||||
|
||||
def recv_meta_helper(prev_rank, tensor_kwargs):
|
||||
|
@ -59,7 +57,7 @@ def recv_meta_helper(prev_rank, tensor_kwargs):
|
|||
return recv_shape
|
||||
|
||||
|
||||
def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size:
|
||||
def recv_obj_meta(prev_rank=None) -> torch.Size:
|
||||
"""Receives obj meta information before receiving a specific obj.
|
||||
Since the recipient must know the shape of the obj in p2p communications,
|
||||
meta information of the obj should be received before communications. This function
|
||||
|
@ -72,21 +70,20 @@ def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size:
|
|||
Returns:
|
||||
Union[:class:`torch.Size`, List[:class:`torch.Size`]]: The shape of the obj to be received.
|
||||
"""
|
||||
if obj_shape is None:
|
||||
if prev_rank is None:
|
||||
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
|
||||
if prev_rank is None:
|
||||
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
|
||||
|
||||
tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
|
||||
recv_obj_nums = torch.empty((), **tensor_kwargs)
|
||||
dist.recv(recv_obj_nums, prev_rank)
|
||||
if recv_obj_nums.item() == 1:
|
||||
tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
|
||||
recv_obj_nums = torch.empty((), **tensor_kwargs)
|
||||
dist.recv(recv_obj_nums, prev_rank)
|
||||
if recv_obj_nums.item() == 1:
|
||||
recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
|
||||
obj_shape = torch.Size(recv_shape)
|
||||
else:
|
||||
obj_shape = []
|
||||
for _ in range(recv_obj_nums.item()):
|
||||
recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
|
||||
obj_shape = torch.Size(recv_shape)
|
||||
else:
|
||||
obj_shape = []
|
||||
for _ in range(recv_obj_nums.item()):
|
||||
recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
|
||||
obj_shape.append(torch.Size(recv_shape))
|
||||
obj_shape.append(torch.Size(recv_shape))
|
||||
|
||||
return obj_shape
|
||||
|
||||
|
|
|
@ -73,6 +73,17 @@ class NaiveAMPModel(nn.Module):
|
|||
input_ = input_.float()
|
||||
return input_
|
||||
|
||||
def convert_to_fp32(self, out):
|
||||
"""Converts the output to fp32"""
|
||||
if isinstance(out, Tensor):
|
||||
out = self._convert_to_fp32(out)
|
||||
elif isinstance(out, (tuple, list)):
|
||||
out = [self._convert_to_fp32(val) for val in out]
|
||||
elif isinstance(out, dict):
|
||||
out = {key: self._convert_to_fp32(val) for key, val in out.items()}
|
||||
|
||||
return out
|
||||
|
||||
def _reduce_module_buffer(self):
|
||||
"""
|
||||
All-reduces the buffers (e.g., running stats of batch normalization) across
|
||||
|
@ -121,10 +132,5 @@ class NaiveAMPModel(nn.Module):
|
|||
out = self.model(*args, **kwargs)
|
||||
|
||||
if self._output_to_fp32:
|
||||
if isinstance(out, Tensor):
|
||||
out = self._convert_to_fp32(out)
|
||||
elif isinstance(out, (tuple, list)):
|
||||
out = [self._convert_to_fp32(val) for val in out]
|
||||
elif isinstance(out, dict):
|
||||
out = {key: self._convert_to_fp32(val) for key, val in out.items()}
|
||||
out = self.convert_to_fp32(out)
|
||||
return out
|
||||
|
|
|
@ -1,5 +1,12 @@
|
|||
from .base_scheduler import BaseScheduler
|
||||
from .base_scheduler import BaseScheduler, SchedulerHook, SchedulerMetricHook
|
||||
from .no_pipeline_scheduler import NonPipelineScheduler
|
||||
from .pipeline_scheduler import InterleavedPipelineScheduler, PipelineScheduler
|
||||
|
||||
__all__ = ["BaseScheduler", "NonPipelineScheduler", "InterleavedPipelineScheduler", "PipelineScheduler"]
|
||||
__all__ = [
|
||||
"BaseScheduler",
|
||||
"NonPipelineScheduler",
|
||||
"InterleavedPipelineScheduler",
|
||||
"PipelineScheduler",
|
||||
"SchedulerHook",
|
||||
"SchedulerMetricHook",
|
||||
]
|
||||
|
|
|
@ -4,11 +4,12 @@
|
|||
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Iterable
|
||||
from typing import Any, Callable, Iterable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from internlm.core.engine import Engine
|
||||
from internlm.utils.megatron_timers import megatron_timer as timer
|
||||
|
||||
|
||||
class BaseScheduler(ABC):
|
||||
|
@ -112,3 +113,85 @@ class BaseScheduler(ABC):
|
|||
'(which is auto-converted to tuple), list, tuple, or dict, ' \
|
||||
'but got {type(outputs)} (model outputs) and {type(labels)} (labels)"
|
||||
)
|
||||
|
||||
|
||||
class SchedulerHook(ABC):
|
||||
"""
|
||||
Scheduler Hook.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def before_forward(self, scheduler, inputs) -> None:
|
||||
"""Actions before forward"""
|
||||
|
||||
@abstractmethod
|
||||
def after_forward(self, scheduler, outputs) -> None:
|
||||
"""Actions after forward"""
|
||||
|
||||
@abstractmethod
|
||||
def before_criterion(self, scheduler, outputs, label) -> None:
|
||||
"""Actions before criterion"""
|
||||
|
||||
@abstractmethod
|
||||
def after_criterion(self, scheduler, loss) -> None:
|
||||
"""Actions after criterion"""
|
||||
|
||||
@abstractmethod
|
||||
def before_backward(self, scheduler, outputs, outputs_grad) -> None:
|
||||
"""Actions before backward"""
|
||||
|
||||
@abstractmethod
|
||||
def after_backward(self, scheduler, inputs_grad) -> None:
|
||||
"""Actions after backward"""
|
||||
|
||||
@abstractmethod
|
||||
def post_helper_func(self, scheduler, outputs, label) -> None:
|
||||
"""A post helper function"""
|
||||
|
||||
|
||||
class SchedulerMetricHook(SchedulerHook):
|
||||
"""
|
||||
Scheduler Metric Hook.
|
||||
"""
|
||||
|
||||
def __init__(self, metric: Optional[Callable] = None, skip: bool = False) -> None:
|
||||
self._post_func = metric
|
||||
self._skip = skip
|
||||
|
||||
if skip:
|
||||
# init timer only.
|
||||
timer("fwd")
|
||||
timer("bwd")
|
||||
timer("cal_loss")
|
||||
timer("post_fn")
|
||||
|
||||
def before_forward(self, scheduler, inputs) -> None:
|
||||
if not self._skip:
|
||||
timer("fwd").start()
|
||||
|
||||
def after_forward(self, scheduler, outputs) -> None:
|
||||
if not self._skip:
|
||||
timer("fwd").stop()
|
||||
|
||||
def before_criterion(self, scheduler, outputs, label) -> None:
|
||||
if not self._skip:
|
||||
timer("cal_loss").start()
|
||||
|
||||
def after_criterion(self, scheduler, loss) -> None:
|
||||
if not self._skip:
|
||||
timer("cal_loss").stop()
|
||||
|
||||
def before_backward(self, scheduler, outputs, outputs_grad) -> None:
|
||||
if not self._skip:
|
||||
timer("bwd").start()
|
||||
|
||||
def after_backward(self, scheduler, inputs_grad) -> None:
|
||||
if not self._skip:
|
||||
timer("bwd").stop()
|
||||
|
||||
def post_helper_func(self, scheduler, outputs, label) -> None:
|
||||
if not self._skip:
|
||||
timer("post_fn").start()
|
||||
if self._post_func is not None:
|
||||
self._post_func(outputs, label)
|
||||
timer("post_fn").stop()
|
||||
|
|
|
@ -3,14 +3,14 @@
|
|||
|
||||
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
|
||||
|
||||
from typing import Any, Callable, Iterable
|
||||
from typing import Any, Callable, Iterable, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from internlm.core.engine import Engine
|
||||
from internlm.utils.common import conditional_context
|
||||
|
||||
from .base_scheduler import BaseScheduler
|
||||
from .base_scheduler import BaseScheduler, SchedulerHook
|
||||
|
||||
|
||||
class NonPipelineScheduler(BaseScheduler):
|
||||
|
@ -34,10 +34,17 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
return data, label
|
||||
"""
|
||||
|
||||
def __init__(self, data_process_func: Callable = None, gradient_accumulation_size: int = 1):
|
||||
def __init__(
|
||||
self,
|
||||
data_process_func: Callable = None,
|
||||
gradient_accumulation_size: int = 1,
|
||||
scheduler_hooks: Optional[List[SchedulerHook]] = None,
|
||||
):
|
||||
self._grad_accum_size = gradient_accumulation_size
|
||||
self._grad_accum_offset = 0
|
||||
|
||||
self._hooks = scheduler_hooks
|
||||
|
||||
super().__init__(data_process_func)
|
||||
|
||||
def pre_processing(self, engine: Engine):
|
||||
|
@ -48,6 +55,10 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
"""
|
||||
pass
|
||||
|
||||
def _call_hooks(self, func_name: str, *args, **kwargs) -> None:
|
||||
for hook in self._hooks:
|
||||
getattr(hook, func_name)(self, *args, **kwargs)
|
||||
|
||||
def _load_accum_batch(self, data: Any, label: Any):
|
||||
"""Loads a batch of data and label for gradient accumulation.
|
||||
|
||||
|
@ -77,7 +88,6 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
forward_only: bool = False,
|
||||
return_loss: bool = True,
|
||||
scale_loss: int = 1,
|
||||
post_fn: Callable = None,
|
||||
):
|
||||
"""Trains one batch of data.
|
||||
|
||||
|
@ -89,23 +99,27 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
be executed.
|
||||
return_loss (bool, optional): Loss will be returned if True.
|
||||
scale_loss (int, optional): The scale factor for the loss.
|
||||
post_fn (Callable, optional): Call back function after executing data forward output.
|
||||
"""
|
||||
|
||||
# forward
|
||||
with conditional_context(torch.no_grad(), enable=forward_only):
|
||||
self._call_hooks("before_forward", data)
|
||||
output = self._call_engine(engine, data)
|
||||
self._call_hooks("after_forward", output)
|
||||
|
||||
if post_fn is not None:
|
||||
post_fn(output, label)
|
||||
self._call_hooks("post_helper_func", output, label)
|
||||
|
||||
if return_loss:
|
||||
self._call_hooks("before_criterion", output, label)
|
||||
loss = self._call_engine_criterion(engine, output, label)
|
||||
self._call_hooks("after_criterion", loss)
|
||||
loss /= scale_loss
|
||||
|
||||
# backward
|
||||
if not forward_only:
|
||||
self._call_hooks("before_backward", None, None)
|
||||
engine.backward(loss)
|
||||
self._call_hooks("after_backward", None)
|
||||
|
||||
if not return_loss:
|
||||
loss = None
|
||||
|
@ -119,7 +133,6 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
forward_only: bool = False,
|
||||
return_loss: bool = True,
|
||||
return_output_label: bool = True,
|
||||
post_fn: Callable = None,
|
||||
):
|
||||
"""The process function that loads a batch of dataset and feeds it to the model.
|
||||
The returned labels and loss will None if :attr:`return_loss` is False.
|
||||
|
@ -131,7 +144,6 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
If True, the model is run for the forward pass, else back propagation will be executed.
|
||||
return_loss (bool, optional): Loss will be returned if True.
|
||||
return_output_label (bool, optional): Output and label will be returned if True.
|
||||
post_fn (Callable, optional): Call back function after executing data forward output.
|
||||
|
||||
Returns:
|
||||
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
|
||||
|
@ -165,7 +177,7 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
_data, _label = self._load_accum_batch(data, label)
|
||||
|
||||
_output, _loss = self._train_one_batch(
|
||||
_data, _label, engine, forward_only, return_loss, self._grad_accum_size, post_fn
|
||||
_data, _label, engine, forward_only, return_loss, self._grad_accum_size
|
||||
)
|
||||
|
||||
if return_loss:
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -3,7 +3,7 @@
|
|||
|
||||
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/initialize
|
||||
|
||||
from typing import Callable, Iterable, Optional, Tuple
|
||||
from typing import Callable, Iterable, List, Optional, Tuple
|
||||
|
||||
from torch import nn
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
@ -15,12 +15,13 @@ from internlm.core.context import ParallelMode
|
|||
from internlm.core.context import global_context as gpc
|
||||
from internlm.core.engine import Engine
|
||||
from internlm.core.gradient_handler import PipelineSharedModuleGradientHandler
|
||||
from internlm.core.scheduler.no_pipeline_scheduler import NonPipelineScheduler
|
||||
from internlm.core.scheduler.pipeline_scheduler import (
|
||||
from internlm.core.scheduler import (
|
||||
InterleavedPipelineScheduler,
|
||||
NonPipelineScheduler,
|
||||
PipelineScheduler,
|
||||
get_tensor_shape,
|
||||
SchedulerHook,
|
||||
)
|
||||
from internlm.core.scheduler.pipeline_scheduler import get_tensor_shape
|
||||
from internlm.core.trainer import Trainer
|
||||
from internlm.data.utils import unpack_data
|
||||
from internlm.solver.beta2_scheduler import Beta2Scheduler
|
||||
|
@ -36,6 +37,7 @@ def initialize_trainer(
|
|||
test_dataloader: Optional[Iterable] = None,
|
||||
lr_scheduler: Optional[_LRScheduler] = None,
|
||||
beta2_scheduler: Optional[Beta2Scheduler] = None,
|
||||
scheduler_hooks: Optional[List[SchedulerHook]] = None,
|
||||
) -> Tuple[Trainer, DataLoader, DataLoader, _LRScheduler]:
|
||||
"""Core function to wrap the essential training components with our functionality based on the config which is
|
||||
loaded into gpc.config.
|
||||
|
@ -92,12 +94,16 @@ def initialize_trainer(
|
|||
if use_interleaved:
|
||||
if isinstance(model, nn.Sequential):
|
||||
model = nn.ModuleList([model])
|
||||
|
||||
communication_overlap = gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
|
||||
scheduler = InterleavedPipelineScheduler(
|
||||
num_microbatches=gpc.config.NUM_MICRO_BATCHES,
|
||||
num_model_chunks=gpc.config.model.num_chunks,
|
||||
num_chunks=gpc.config.model.num_chunks,
|
||||
dtype=gpc.config.model["dtype"],
|
||||
tensor_shape=tensor_shape,
|
||||
scatter_gather_tensors=scatter_gather,
|
||||
scheduler_hooks=scheduler_hooks,
|
||||
communication_overlap=communication_overlap,
|
||||
)
|
||||
else:
|
||||
scheduler = PipelineScheduler(
|
||||
|
@ -106,10 +112,13 @@ def initialize_trainer(
|
|||
dtype=gpc.config.model["dtype"],
|
||||
tensor_shape=tensor_shape,
|
||||
scatter_gather_tensors=scatter_gather,
|
||||
scheduler_hooks=scheduler_hooks,
|
||||
)
|
||||
else:
|
||||
scheduler = NonPipelineScheduler(
|
||||
data_process_func=data_fn, gradient_accumulation_size=gpc.config.data.gradient_accumulation
|
||||
data_process_func=data_fn,
|
||||
gradient_accumulation_size=gpc.config.data.gradient_accumulation,
|
||||
scheduler_hooks=scheduler_hooks,
|
||||
)
|
||||
|
||||
# initialize engine for trainer
|
||||
|
|
|
@ -7,40 +7,47 @@ from tqdm import tqdm
|
|||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.model.metrics import AccPerplex
|
||||
from internlm.core.scheduler import SchedulerMetricHook
|
||||
|
||||
|
||||
@contextmanager
|
||||
def switch_evaluation_no_pipeline_scheduler(trainer, grad_accum_size, grad_accum_batch_size):
|
||||
def switch_evaluation_no_pipeline_scheduler(trainer, grad_accum_size, grad_accum_batch_size, metric_hook_list):
|
||||
if not gpc.is_using_pp():
|
||||
prev_data_process_func = trainer.schedule.data_process_func
|
||||
prev_grad_accum_size = trainer.schedule._grad_accum_size
|
||||
prev_grad_accum_batch_size = trainer.schedule._grad_accum_batch_size
|
||||
prev_metric_hooks = trainer.schedule._hooks
|
||||
try:
|
||||
trainer.schedule.data_process_func = None
|
||||
trainer.schedule._grad_accum_size = grad_accum_size
|
||||
trainer.schedule._grad_accum_batch_size = grad_accum_batch_size
|
||||
trainer.schedule._hooks = metric_hook_list
|
||||
yield
|
||||
finally:
|
||||
trainer.schedule.data_process_func = prev_data_process_func
|
||||
trainer.schedule._grad_accum_size = prev_grad_accum_size
|
||||
trainer.schedule._grad_accum_batch_size = prev_grad_accum_batch_size
|
||||
trainer.schedule._hooks = prev_metric_hooks
|
||||
|
||||
|
||||
@contextmanager
|
||||
def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape):
|
||||
def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape, metric_hook_list):
|
||||
if gpc.is_using_pp():
|
||||
pre_data_process_func = trainer.schedule.data_process_func
|
||||
prev_num_microbatches = trainer.schedule.num_microbatches
|
||||
prev_tensor_shape = trainer.schedule.tensor_shape
|
||||
prev_metric_hooks = trainer.schedule._hooks
|
||||
try:
|
||||
trainer.schedule.data_process_func = None
|
||||
trainer.schedule.num_microbatches = num_microbatches
|
||||
trainer.schedule.tensor_shape = tensor_shape
|
||||
trainer.schedule._hooks = metric_hook_list
|
||||
yield
|
||||
finally:
|
||||
trainer.schedule.data_process_func = pre_data_process_func
|
||||
trainer.schedule.num_microbatches = prev_num_microbatches
|
||||
trainer.schedule.tensor_shape = prev_tensor_shape
|
||||
trainer.schedule._hooks = prev_metric_hooks
|
||||
|
||||
|
||||
def evaluate_on_val_dls(
|
||||
|
@ -49,7 +56,6 @@ def evaluate_on_val_dls(
|
|||
writer,
|
||||
logger,
|
||||
step_count,
|
||||
tokenizer=None,
|
||||
update_panel: bool = False,
|
||||
):
|
||||
torch.cuda.empty_cache()
|
||||
|
@ -66,8 +72,9 @@ def evaluate_on_val_dls(
|
|||
device=torch.cuda.current_device(),
|
||||
tp_pg=gpc.get_group(ParallelMode.TENSOR),
|
||||
dp_pg=gpc.get_group(ParallelMode.DATA),
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
val_sche_metric_hook = SchedulerMetricHook(metric=val_metric)
|
||||
|
||||
val_loss = 0
|
||||
val_idx = -1
|
||||
for val_idx, batch in tqdm(
|
||||
|
@ -88,10 +95,13 @@ def evaluate_on_val_dls(
|
|||
)
|
||||
|
||||
with switch_evaluation_pipeline_scheduler(
|
||||
trainer=trainer, num_microbatches=num_microbatches, tensor_shape=tensor_shape
|
||||
trainer=trainer,
|
||||
num_microbatches=num_microbatches,
|
||||
tensor_shape=tensor_shape,
|
||||
metric_hook_list=[val_sche_metric_hook],
|
||||
):
|
||||
_, _, loss = trainer.execute_schedule(
|
||||
batch, forward_only=True, return_loss=True, return_output_label=False, post_fn=val_metric
|
||||
batch, forward_only=True, return_loss=True, return_output_label=False
|
||||
)
|
||||
else:
|
||||
total_val_bsz = len(batch[1])
|
||||
|
@ -100,38 +110,42 @@ def evaluate_on_val_dls(
|
|||
grad_accum_batch_size = data_cfg.micro_bsz
|
||||
|
||||
with switch_evaluation_no_pipeline_scheduler(
|
||||
trainer=trainer, grad_accum_size=grad_accum_size, grad_accum_batch_size=grad_accum_batch_size
|
||||
trainer=trainer,
|
||||
grad_accum_size=grad_accum_size,
|
||||
grad_accum_batch_size=grad_accum_batch_size,
|
||||
metric_hook_list=[val_sche_metric_hook],
|
||||
):
|
||||
_, _, loss = trainer.execute_schedule(
|
||||
batch, forward_only=True, return_loss=True, return_output_label=False, post_fn=val_metric
|
||||
batch, forward_only=True, return_loss=True, return_output_label=False
|
||||
)
|
||||
if verbose:
|
||||
val_loss += loss.item()
|
||||
|
||||
assert val_idx != -1
|
||||
dist.barrier()
|
||||
val_res = val_metric.get_metric()
|
||||
|
||||
val_res = val_metric.get_metric()
|
||||
if verbose and len(val_dl) != 0:
|
||||
val_loss = val_loss / (val_idx + 1 + 1e-6)
|
||||
infos = {
|
||||
"step": step_count,
|
||||
f"val/{val_name}_loss": val_loss,
|
||||
f"val/{val_name}_acc": val_res["acc"],
|
||||
f"val/{val_name}_plex": val_res["perplexity"],
|
||||
}
|
||||
val_metric = {
|
||||
"step": step_count,
|
||||
"val_loss": val_loss,
|
||||
"val_acc": val_res["acc"],
|
||||
"val_perplexity": val_res["perplexity"],
|
||||
}
|
||||
|
||||
for key, value in infos.items():
|
||||
writer.add_scalar(key=key, value=value, step=step_count)
|
||||
infos["step"] = step_count
|
||||
|
||||
if update_panel:
|
||||
logger.info(
|
||||
f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()]),
|
||||
extra=val_metric,
|
||||
extra={
|
||||
"step": step_count,
|
||||
"val_loss": val_loss,
|
||||
"val_acc": val_res["acc"],
|
||||
"val_perplexity": val_res["perplexity"],
|
||||
},
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
|
|
64
train.py
64
train.py
|
@ -16,6 +16,7 @@ import internlm
|
|||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.core.naive_amp import NaiveAMPModel
|
||||
from internlm.core.scheduler import SchedulerMetricHook
|
||||
from internlm.core.trainer import TrainState
|
||||
from internlm.data.batch_sampler import StaticBatchSampler, get_dpsampler_dataloader
|
||||
from internlm.data.collaters import jsonl_ds_collate_fn, packed_collate_fn
|
||||
|
@ -109,12 +110,25 @@ def initialize_model():
|
|||
"""
|
||||
|
||||
model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(gpc.config.model))
|
||||
model = NaiveAMPModel(
|
||||
model=model,
|
||||
output_to_fp32=is_no_pp_or_last_stage(),
|
||||
dtype=gpc.config.model.get("dtype", torch.half),
|
||||
sync_buffer=False,
|
||||
)
|
||||
if isinstance(model, nn.ModuleList):
|
||||
model = nn.ModuleList(
|
||||
[
|
||||
NaiveAMPModel(
|
||||
model=_m,
|
||||
output_to_fp32=False, # manually controlled by interleaved pipleline scheduler
|
||||
dtype=gpc.config.model.get("dtype", torch.half),
|
||||
sync_buffer=False,
|
||||
)
|
||||
for _m in model
|
||||
]
|
||||
)
|
||||
else:
|
||||
model = NaiveAMPModel(
|
||||
model=model,
|
||||
output_to_fp32=is_no_pp_or_last_stage(),
|
||||
dtype=gpc.config.model.get("dtype", torch.half),
|
||||
sync_buffer=False,
|
||||
)
|
||||
|
||||
# This sync is very important, cause the model weights kept in optimizer are copied
|
||||
# from the origin parameters in the memory, so we should make sure the dp sync
|
||||
|
@ -500,19 +514,6 @@ def main(args):
|
|||
if load_optimizer:
|
||||
load_optimizer_checkpoint(load_resume_ckpt_folder, optimizer)
|
||||
|
||||
# initialize trainer
|
||||
trainer, train_dl, _, _ = internlm.initialize_trainer(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
train_dataloader=train_dl,
|
||||
lr_scheduler=lr_scheduler,
|
||||
beta2_scheduler=beta2_scheduler,
|
||||
)
|
||||
|
||||
# initialize the batch skipper
|
||||
batch_skipper = BatchSkipper(skip_batches)
|
||||
|
||||
# initialize metric for calculating accuracy and perplexity
|
||||
metric = AccPerplex(
|
||||
device=torch.cuda.current_device(),
|
||||
|
@ -521,6 +522,27 @@ def main(args):
|
|||
dataset_types=dataset_types,
|
||||
)
|
||||
|
||||
# initialize trainer
|
||||
scheduler_hooks = [
|
||||
SchedulerMetricHook(
|
||||
metric=metric,
|
||||
skip=gpc.is_using_pp() and gpc.config.parallel["pipeline"].get("interleaved_overlap", False),
|
||||
),
|
||||
]
|
||||
|
||||
trainer, train_dl, _, _ = internlm.initialize_trainer(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
train_dataloader=train_dl,
|
||||
lr_scheduler=lr_scheduler,
|
||||
beta2_scheduler=beta2_scheduler,
|
||||
scheduler_hooks=scheduler_hooks,
|
||||
)
|
||||
|
||||
# initialize the batch skipper
|
||||
batch_skipper = BatchSkipper(skip_batches)
|
||||
|
||||
trainer.train()
|
||||
|
||||
# transfer the train data loader into train data iterator
|
||||
|
@ -558,9 +580,7 @@ def main(args):
|
|||
|
||||
# do forward and backward
|
||||
timer("fwd-bwd").start()
|
||||
_, _, loss = trainer.execute_schedule(
|
||||
batch, forward_only=False, return_loss=True, return_output_label=False, post_fn=metric
|
||||
)
|
||||
_, _, loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False)
|
||||
timer("fwd-bwd").stop()
|
||||
|
||||
# update parameters, and returns (success_update, grad_norm)
|
||||
|
|
Loading…
Reference in New Issue