mirror of https://github.com/hpcaitech/ColossalAI
[fix] fix zerobubble pp for shardformer type input;
parent 9bc3b6e220
commit 3dbad102cf
@@ -131,6 +131,16 @@ def retain_grad(x: Any) -> None:
         x.retain_grad()
+
+
+def require_grad(x: Any) -> None:
+    """Call require_grad on a tensor.
+
+    Args:
+        x (Any): Object to be called.
+    """
+    if isinstance(x, torch.Tensor) and x.requires_grad:
+        x.requires_grad_()
 
 
 def detach(x: Any) -> Any:
     """Call detach() on a tensor.
@@ -145,6 +155,34 @@ def detach(x: Any) -> Any:
     return x
 
 
+def clone(x: Any) -> Any:
+    """Call clone() on a tensor.
+
+    Args:
+        x (Any): Object to be called.
+
+    Returns:
+        Any: The cloned object.
+    """
+    if isinstance(x, torch.Tensor):
+        return x.clone()
+    return x
+
+
+def deallocate(x: Any) -> Any:
+    """Call deallocate() on a tensor.
+
+    Args:
+        x (Any): Object to be called.
+
+    Returns:
+        Any: The deallocated .data object.
+    """
+    if isinstance(x, torch.Tensor):
+        return x.data.untyped_storage().resize_(0)
+    return x
+
+
 def merge_batch(data: List[Any], batch_size_dim=0) -> Any:
     """Merge micro batches into a batch.
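These helpers are written to be applied element-wise over whatever a stage passes around (a bare tensor or a dict of tensors), typically through torch.utils._pytree.tree_map. A minimal usage sketch (my own illustration; it assumes the helpers are importable from colossalai.pipeline.schedule._utils, as the import change further down suggests):

import torch
from torch.utils._pytree import tree_map

from colossalai.pipeline.schedule._utils import clone, detach

# A shardformer-style stage output: a dict of tensors.
stage_output = {"hidden_states": torch.randn(2, 4, requires_grad=True)}

kept_for_backward = tree_map(clone, stage_output)  # keeps grad_fn, used later for the backward b/w steps
sent_downstream = tree_map(detach, stage_output)   # graph-free copy that crosses the P2P boundary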
@@ -12,7 +12,7 @@ from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.schedule.v_schedule import ScheduledNode
 from colossalai.pipeline.stage_manager import PipelineStageManager
 
-from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, retain_grad, to_device
+from ._utils import clone, detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
 from .base import PipelineSchedule
 
 AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"}
@@ -39,6 +39,20 @@ def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
     out.data.untyped_storage().resize_(0)
 
 
+def require_grad(tensor):
+    """Mark a received tensor as requiring grad so that autograd records a graph through it.
+
+    This method should be called on the tensor received from the previous
+    pipeline stage, before it is fed into the local model chunk; its '.grad'
+    field is read back later in the backward b step.
+    """
+    if tensor is None:
+        return
+    assert isinstance(tensor, torch.Tensor), "expected Tensor, found %s." % type(tensor).__name__
+    assert tensor._base is None, "counter-productive to free a view of another tensor."
+    tensor.requires_grad_()
+
+
 class ZeroBubbleVPipeScheduler(PipelineSchedule):
     def __init__(
         self,
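For context on why the deallocate trick above is safe: resizing a tensor's untyped storage to zero frees the payload but leaves its shape metadata and grad_fn intact, so backward can still be driven through it. A standalone sketch of the idea (my own illustration, not part of the commit; it assumes the op that produced the output, e.g. nn.Linear, does not save its output for backward):

import torch
import torch.nn as nn

lin = nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)
out = lin(x)

sent = out.clone().detach()            # graph-free copy that would be sent to the next stage
out.data.untyped_storage().resize_(0)  # free the payload, keep grad_fn and shape metadata

# Backward still works: only the graph hanging off out.grad_fn is needed.
torch.autograd.backward(out, grad_tensors=torch.ones(2, 4))
print(x.grad.shape)  # torch.Size([2, 4])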
@@ -409,6 +423,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self,
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
+        micro_batch: Optional[dict],
         input_obj: Optional[dict],
         criterion: Callable,
         accum_loss: Optional[torch.Tensor] = None,
@@ -427,18 +442,27 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
         """
         # Load input ids, attention mask and labels
         # micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
 
-        # for the first stage, input_obj is None
+        # for the first stage, input_obj is None, so we use micro_batch as input_obj
+        # for other stages, input_obj is the output of the previous/next stage containing hidden_states etc.
+        # Only attention_mask from micro_batch is used
 
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
-            # fwd calculate
-            output_obj = model_chunk[model_chunk_id](input_obj)
+            if isinstance(model_chunk, ModuleList):
+                # fwd for ModuleList model
+                if input_obj is None:
+                    output_obj = model_chunk[model_chunk_id](**micro_batch)
+                else:
+                    output_obj = model_chunk[model_chunk_id](**input_obj)
+            else:
+                # fwd for shardformer
+                # NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers
+                internal_inputs = {} if input_obj is None else input_obj
+                # internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
+                output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)
 
         # last layer in model
         if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            loss = criterion(output_obj) / self.num_microbatch
+            loss = criterion(output_obj, micro_batch) / self.num_microbatch
             if accum_loss is not None:
                 accum_loss.add_(loss.detach())
             if outputs is not None:
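The shardformer branch relies on a model_forward helper newly imported from ._utils. Its exact implementation is not shown in this hunk; a plausible minimal version, consistent with how it is called here (merge the user-facing micro batch with the inter-stage tensors received from the previous stage, then call the module with keyword arguments), would be:

from typing import Any, Dict, Optional

from torch.nn import Module


def model_forward(model: Module, micro_batch: Optional[Dict[str, Any]], inputs: Optional[Dict[str, Any]]) -> Any:
    # Sketch only: later keys (inter-stage hidden_states etc.) override micro-batch keys on conflict.
    kwargs = {**(micro_batch or {}), **(inputs or {})}
    return model(**kwargs)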
@@ -472,19 +496,25 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # calculate bwd b step ; only dx = w*dy;
 
         # Retain the grad on the input_obj.
         if input_obj is None:
             return None
         else:
             tree_map(retain_grad, input_obj)
+            input_obj_ = input_obj["hidden_states"]
 
         if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             # loss backward; output_obj is loss; so output_obj_grad should be None
             assert output_obj_grad is None
+            output_obj_ = output_obj
+        else:
+            output_obj_ = output_obj["hidden_states"]
         optimizer.backward_by_grad(
-            tensor=output_obj,
+            tensor=output_obj_,
             grad=output_obj_grad,
-            inputs=input_obj,
+            inputs=input_obj_,
             retain_graph=True,
         )
-        return input_obj.grad
+        return input_obj_.grad
 
     def backward_w_step(
         self,
@@ -511,8 +541,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             # loss backward; output_obj is loss
             output_obj_grad = None
+            output_obj_ = output_obj
+        else:
+            output_obj_ = output_obj["hidden_states"]
         optimizer.backward_by_grad(
-            tensor=output_obj,
+            tensor=output_obj_,
             grad=output_obj_grad,
             inputs=list(model_chunk[model_chunk_id].parameters()),
             retain_graph=False,
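backward_b_step and backward_w_step together split one backward pass into an input-gradient pass (dx, with retain_graph=True) and a later weight-gradient pass (dw), which is the core of the zero-bubble schedule. A minimal standalone sketch of that split using plain torch.autograd.backward (the scheduler itself goes through optimizer.backward_by_grad, which as used here drives the same autograd machinery):

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)
y = layer(x)
dy = torch.ones_like(y)

# B step: only dx; weight grads are not touched yet.
torch.autograd.backward(y, grad_tensors=dy, inputs=[x], retain_graph=True)
assert layer.weight.grad is None and x.grad is not None

# W step (can be scheduled later, filling the pipeline bubble): only dW.
torch.autograd.backward(y, grad_tensors=dy, inputs=list(layer.parameters()))
assert layer.weight.grad is not None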
@@ -543,9 +576,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
         # Step1: recv fwd
         if model_chunk_id == 0:
-            # is first stage; get input from func param
+            # is first stage; get input from microbatch
             if self.stage_manager.is_first_stage(ignore_chunk=True):
-                input_obj = micro_batch
+                input_obj = None
             else:
                 input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
         else:
@@ -557,45 +590,68 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
 
         # Here, let input_obj.requires_grad_()
-        tree_map(torch.Tensor.requires_grad_, input_obj)
+        if input_obj is not None:
+            tree_map(require_grad, input_obj)
+
+        # Also requires_grad_ for micro_batch in stage 0 chunk 0 fwd,
+        # tree_map(torch.Tensor.requires_grad_, micro_batch)
 
         # Step2: fwd step
         output_obj = self.forward_step(
             model_chunk=model_chunk,
             model_chunk_id=model_chunk_id,
+            micro_batch=micro_batch,
             input_obj=input_obj,
             criterion=criterion,
             accum_loss=accum_loss,
             outputs=outputs,
         )
 
+        # Step3: deallocate output for bwd b & w; (do not detach output)
+        deallocate_output_obj = tree_map(clone, output_obj)
+        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
+            # We should not deallocate bwd LOSS
+            pass
+        else:
+            # deallocate output
+            tree_map(partial(deallocate_output_tensor, deallocate_pipeline_outputs=True), deallocate_output_obj)
+
+        # add input and output object for backward b
+        if input_obj is not None:
+            self.input_tensors[model_chunk_id].append(input_obj)
+        else:
+            self.input_tensors[model_chunk_id].append(micro_batch)
+
+        # for bwd b&w, we only need the graph(grad_fn) of output_obj
+        # Do not deallocate loss, deallocate other output_obj;
+        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
+            self.output_tensors[model_chunk_id].append(deallocate_output_obj)
+            self.output_tensors_dw[model_chunk_id].append(deallocate_output_obj)
+        else:
+            self.output_tensors[model_chunk_id].append(deallocate_output_obj)
+            self.output_tensors_dw[model_chunk_id].append(deallocate_output_obj)
+
+        # Step4: detach output for send fwd;
         if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             # We should not detach bwd LOSS
             pass
         else:
-            detached_output_obj = output_obj.clone().detach()
+            # detach output
+            output_obj = tree_map(detach, output_obj)
 
         # Step3: send fwd
         # add output to send_fwd_buffer
-        if model_chunk_id == 0:
+        if model_chunk_id == 0:  # chunk 0
             # is last stage; send to local_send_forward_buffer
             if self.stage_manager.is_last_stage(ignore_chunk=True):
-                self.local_send_forward_buffer.append(detached_output_obj)
+                self.local_send_forward_buffer.append(output_obj)
             else:
-                self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
-        else:
-            # is first stage; end of fwd; append LOSS to local_send_backward_buffer
+                self.send_forward_buffer[model_chunk_id].append(output_obj)
+        else:  # chunk 1
+            # is first stage; end of fwd; do nothing
             if self.stage_manager.is_first_stage(ignore_chunk=True):
                 pass
             else:
-                self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
-
-        # add input and output object for backward b
-        self.input_tensors[model_chunk_id].append(input_obj)
-        # detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj
-        deallocate_output_tensor(output_obj, deallocate_pipeline_outputs=True)
-        self.output_tensors[model_chunk_id].append(output_obj)
-        # add output object for backward w
-        self.output_tensors_dw[model_chunk_id].append(output_obj)
+                self.send_forward_buffer[model_chunk_id].append(output_obj)
 
     def schedule_b(
         self,
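Taken together, schedule_f and schedule_b follow the usual pattern for passing activations between stages that do not share an autograd graph: send a detached copy forward, re-enable grad on it at the receiver, and ship the receiver's input gradient back to serve as the grad for the sender's kept output. A minimal two-stage sketch of that round trip (illustration only; the real scheduler moves these tensors through send/recv buffers and further splits the sender's backward into b and w steps):

import torch
import torch.nn as nn

stage0, stage1 = nn.Linear(4, 4), nn.Linear(4, 4)
x = torch.randn(2, 4)

# stage 0: forward, keep the graph-carrying output, "send" a detached copy
out0 = stage0(x)
sent = out0.detach()

# stage 1: re-enable grad on the received tensor, forward, backward b (dx only)
sent.requires_grad_()
loss = stage1(sent).pow(2).mean()
loss.backward(inputs=[sent])           # produces sent.grad, i.e. the gradient to send back

# stage 0: backward with the received output gradient
out0.backward(gradient=sent.grad)
print(stage0.weight.grad is not None)  # True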
@@ -603,9 +659,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
         optimizer: OptimizerWrapper,
-        # input_obj: Optional[dict],
-        # output_obj: Union[dict, torch.Tensor],
-        # output_obj_grad: Optional[dict],
     ):
         """A complete backward b schedule; Include recv bwd --> cal bwd step --> send bwd;
@@ -616,7 +669,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         Returns:
             Nothing.
         """
-
         # Step1: recv bwd
         if model_chunk_id == 0:
             # chunk0 is last stage; recv output_grad from local_send_backward_buffer
@@ -645,7 +697,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             # we save output_tensor_grad here
             self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
 
-        # _wait_p2p(recv_bwd_handles)
         # Step2: bwd step
         input_object_grad = self.backward_b_step(
             model_chunk=model_chunk,
@@ -777,8 +828,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 # communication
                 communication_func = self.communication_map[scheduled_node.type]
                 communication_func(scheduled_node.chunk)
 
-            if scheduled_node.type == "F":
+            elif scheduled_node.type == "F":
                 self.schedule_f(
                     scheduled_node=scheduled_node,
                     model_chunk=model_chunk,
@@ -1,4 +1,6 @@
+from copy import deepcopy
+from functools import partial
 from types import MethodType
 from typing import Tuple
 
 import pytest
@@ -16,7 +18,8 @@ from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import model_zoo
 
-# from tests.kit.model_zoo import model_zoo
 
 class MlpModel(nn.Module):
@@ -24,10 +27,32 @@ class MlpModel(nn.Module):
         super().__init__()
         self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)])
 
-    def forward(self, x):
+    def forward(
+        self,
+        hidden_states,
+    ):
         for layer in self.layers:
-            x = layer(x)
-        return x
+            hidden_states = layer(hidden_states)
+        return hidden_states
 
 
 def pp_linear_fwd(
     forward,
     data: torch.Tensor = None,
+    hidden_states: torch.Tensor = None,
     stage_mgr: PipelineStageManager = None,
     model_chunk_id: int = None,
 ):
+    with stage_mgr.switch_model_chunk_id(model_chunk_id):
+        # fwd end
+        if stage_mgr.is_first_stage() and model_chunk_id == 1:
+            return forward(hidden_states)
+        # fwd start
+        elif stage_mgr.is_first_stage() and model_chunk_id == 0:
+            return {"hidden_states": forward(hidden_states)}
+        # fwd middle
+        else:
+            return {"hidden_states": forward(hidden_states)}
 
 
 def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
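In the test below, each nn.Linear gets this wrapper attached via MethodType(partial(pp_linear_fwd, stage_mgr=..., model_chunk_id=...), sub_model._forward): MethodType binds the saved original forward into the wrapper's first positional slot, so pp_linear_fwd receives it as `forward` while the keyword arguments stay fixed by partial. A tiny standalone sketch of that binding trick (illustrative names only; `tag` stands in for model_chunk_id):

from functools import partial
from types import MethodType


def wrapper(forward, hidden_states=None, tag=None):
    # `forward` is the original bound method, injected by MethodType below.
    return {"tag": tag, "out": forward(hidden_states)}


class Layer:
    def forward(self, x):
        return x * 2


layer = Layer()
layer._forward = layer.forward
# Bind: the "self" slot of the new method is the saved original forward.
layer.forward = MethodType(partial(wrapper, tag="chunk0"), layer._forward)
print(layer.forward(hidden_states=3))  # {'tag': 'chunk0', 'out': 6}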
@@ -510,15 +535,15 @@ def run_fwd_bwd_iter_input(test_config):
             "precision": "bf16",
             "num_model_chunk": 2,
         },
-        {
-            "batch_size": 8,
-            "tp_size": 1,
-            "pp_size": 4,
-            "num_microbatches": 8,
-            "zero_stage": 1,
-            "precision": "bf16",
-            "num_model_chunk": 2,
-        },
+        # {
+        #     "batch_size": 8,
+        #     "tp_size": 1,
+        #     "pp_size": 4,
+        #     "num_microbatches": 8,
+        #     "zero_stage": 1,
+        #     "precision": "bf16",
+        #     "num_model_chunk": 2,
+        # },
     ],
 )
 def run_fwd_bwd_vschedule_with_optim(test_config):
@@ -562,6 +587,10 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
 
     # init loss func
     def criterion(x, *args, **kwargs):
+        x = x["hidden_states"]
         return (x * x).mean()
 
+    def criterion_base(x, *args, **kwargs):
+        return (x * x).mean()
+
     # init model and input
@@ -572,9 +601,10 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     before_init_memory = torch.cuda.memory_allocated() / 1024**3
     print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};")
     model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
-    data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
-
-    input_base = [t.clone() for t in data_iter]
+    # data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
+    data_iter = {"hidden_states": torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)}
+    # input_base = [t.clone() for t in data_iter]
+    input_base = {k: v.clone() for k, v in data_iter.items()}
     model_base = deepcopy(model)
 
     if rank == 0:
@@ -582,24 +612,44 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
         local_chunk = torch.nn.ModuleList().to(rank)
         for idx, sub_model in enumerate(model.layers):
             if idx == 0 or idx == 7:
+                sub_model._forward = sub_model.forward
+                sub_model.forward = MethodType(
+                    partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
+                    sub_model._forward,
+                )
                 local_chunk.append(sub_model)
     elif rank == 1:
         # layer 1 & 6 to chunk 1 on rank1
         local_chunk = torch.nn.ModuleList().to(rank)
         for idx, sub_model in enumerate(model.layers):
             if idx == 1 or idx == 6:
+                sub_model._forward = sub_model.forward
+                sub_model.forward = MethodType(
+                    partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
+                    sub_model._forward,
+                )
                 local_chunk.append(sub_model)
     elif rank == 2:
         # layer 2 & 5 to chunk 2 on rank2
         local_chunk = torch.nn.ModuleList().to(rank)
         for idx, sub_model in enumerate(model.layers):
             if idx == 2 or idx == 5:
+                sub_model._forward = sub_model.forward
+                sub_model.forward = MethodType(
+                    partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
+                    sub_model._forward,
+                )
                 local_chunk.append(sub_model)
     else:
        # layer 3 & 4 to chunk 3 on rank3
         local_chunk = torch.nn.ModuleList().to(rank)
         for idx, sub_model in enumerate(model.layers):
             if idx == 3 or idx == 4:
+                sub_model._forward = sub_model.forward
+                sub_model.forward = MethodType(
+                    partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
+                    sub_model._forward,
+                )
                 local_chunk.append(sub_model)
 
     # init optimizer
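The hard-coded pairs above (rank 0 holds layers 0 and 7, rank 1 holds 1 and 6, ...) are the usual V-shaped assignment for two chunks per rank: chunk 0 takes layer `rank`, chunk 1 takes layer `2 * pp_size - 1 - rank`. A small sketch of that mapping (my own illustration of the pattern, not code from the test):

pp_size, num_layers = 4, 8  # matches the 8-layer MlpModel split over 4 ranks

for rank in range(pp_size):
    chunk0_layer = rank                    # forward goes "down" the V
    chunk1_layer = 2 * pp_size - 1 - rank  # then back "up" the V
    print(f"rank {rank}: chunk0 -> layer {chunk0_layer}, chunk1 -> layer {chunk1_layer}")
# rank 0: chunk0 -> layer 0, chunk1 -> layer 7
# rank 1: chunk0 -> layer 1, chunk1 -> layer 6
# ...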
@@ -612,7 +662,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     torch.cuda.synchronize()
     result = scheduler.forward_backward_step(
         model_chunk=local_chunk,
-        data_iter=iter(data_iter),
+        data_iter=iter([data_iter]),
         criterion=criterion,
         optimizer=optimizer_pp,
         return_loss=True,
@@ -643,8 +693,8 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     # Fwd bwd for base
     ##########################
     # fwd & bwd
-    output_base = model_base(input_base[0])
-    loss_base = criterion(output_base)
+    output_base = model_base(input_base["hidden_states"])
+    loss_base = criterion_base(output_base)
     loss_base.backward()
     optimizer_base.step()
@@ -654,7 +704,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     # only chunk 1 stage 0 hold loss and output
     if rank == 0:
         assert_close(result["loss"], loss_base)
-        assert_close(result["outputs"], output_base)
+        assert_close(result["outputs"]["hidden_states"], output_base)
 
     # print(f"pp result {result}; base result loss:{loss_base} output_base:{output_base} ")
     ##########################
@@ -727,6 +777,7 @@ def run_with_hybridplugin(test_config):
         {
             "pp_style": "zbv",
             "tp_size": 1,
+            "ep_size": 1,
             "pp_size": 4,
             "num_microbatches": 4,
             "zero_stage": 1,
@@ -737,7 +788,7 @@ def run_with_moehybridplugin(test_config):
 )
 def run_with_moehybridplugin(test_config):
     sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
-    test_config["use_lazy_init"] = False
+    # test_config["use_lazy_init"] = False
     test_config["initial_scale"] = 2**16
     model_list = [
         "transformers_bert",
@@ -749,6 +800,7 @@ def run_with_moehybridplugin(test_config):
         # base param
         model = model_fn()
         data = data_gen_fn()
+        print(f"data {data}")
         criterion = loss_fn
         optimizer = torch.optim.SGD(model.parameters(), momentum=0.1, lr=1e-5)
@@ -787,7 +839,7 @@ def run_with_moehybridplugin(test_config):
         # plugin = MoeHybridParallelPlugin(
         #     **test_config
         # )
-        # model_pp, optimizer_pp, criterion, data_pp = plugin.configure(
+        # model_pp, optimizer_pp, criterion, data_pp, _ = plugin.configure(
         #     model = model_pp,
         #     optimizer = optimizer_pp,
         #     criterion = criterion,
@@ -806,16 +858,34 @@ def run_with_moehybridplugin(test_config):
 
 # TODO:6) support booster & Hybrid base 4)
 
 
+# TODO:7) support booster & MoEHybrid base 4)
+@parameterize(
+    "test_config",
+    [
+        {
+            "pp_style": "zbv",
+            "tp_size": 1,
+            "ep_size": 1,
+            "pp_size": 4,
+            "num_microbatches": 4,
+            "zero_stage": 1,
+            "precision": "bf16",
+            "num_model_chunks": 2,
+        },
+    ],
+)
+def run_with_booster_moehybridplugin(test_config):
+    pass
+
+
 def run_dist(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     # run_fwd_bwd_iter_input()
-    # run_fwd_bwd_vschedule_with_optim()
+    run_fwd_bwd_vschedule_with_optim()
-    # run_with_moehybridplugin()
+    run_with_moehybridplugin()
+    # run_with_booster_moehybridplugin()
 
 
 @pytest.mark.dist