mirror of https://github.com/hpcaitech/ColossalAI
Merge pull request #6069 from duanjunwen/dev/zero_bubble
[HotFix] Fix stage_index in zerobubble test (ref: pull/6075/head)
commit
b804fdc297
|
@ -430,8 +430,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||||
# fwd calculate
|
# fwd calculate
|
||||||
internal_inputs = {} if input_obj is None else input_obj
|
internal_inputs = {} if input_obj is None else input_obj
|
||||||
# internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
|
internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
|
||||||
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)
|
output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
|
||||||
|
|
||||||
# last layer in model
|
# last layer in model
|
||||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
|
@ -449,7 +449,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
model_chunk: Union[ModuleList, Module],
|
model_chunk: Union[ModuleList, Module],
|
||||||
model_chunk_id: int,
|
model_chunk_id: int,
|
||||||
optimizer: OptimizerWrapper,
|
optimizer: OptimizerWrapper,
|
||||||
micro_batch: Optional[dict],
|
# micro_batch: Optional[dict],
|
||||||
input_obj: Optional[dict],
|
input_obj: Optional[dict],
|
||||||
output_obj: Union[dict, torch.Tensor],
|
output_obj: Union[dict, torch.Tensor],
|
||||||
output_obj_grad: Optional[dict],
|
output_obj_grad: Optional[dict],
|
||||||
|
@ -478,11 +478,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
output_obj_ = []
|
output_obj_ = []
|
||||||
output_obj_grad_ = []
|
output_obj_grad_ = []
|
||||||
|
|
||||||
# For chunk 0 stage 0, use micro_batch as input_obj_
|
# For chunk 0 stage 0, the input is the micro_batch itself, so we don't have to calculate dx for the micro batch.
|
||||||
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
input_obj_, _ = tree_flatten(micro_batch)
|
return None
|
||||||
output_obj_, _ = tree_flatten(output_obj) # y
|
|
||||||
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
|
|
||||||
|
|
||||||
# For loss backward; output_obj is loss; output_obj_grad should be None
|
# For loss backward; output_obj is loss; output_obj_grad should be None
|
||||||
elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
|
@ -497,6 +495,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
output_obj_, _ = tree_flatten(output_obj) # y
|
output_obj_, _ = tree_flatten(output_obj) # y
|
||||||
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
|
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
|
||||||
|
|
||||||
|
# filter item which is not torch.Tensor
|
||||||
|
input_obj_ = [v for v in input_obj_ if isinstance(v, torch.Tensor) or v is None]
|
||||||
|
output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
|
||||||
|
output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]
|
||||||
|
|
||||||
optimizer.backward_by_grad(
|
optimizer.backward_by_grad(
|
||||||
tensor=output_obj_,
|
tensor=output_obj_,
|
||||||
grad=output_obj_grad_,
|
grad=output_obj_grad_,
|
||||||
|
@ -507,9 +510,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
# Format output_obj_grad
|
# Format output_obj_grad
|
||||||
input_obj_grad = {}
|
input_obj_grad = {}
|
||||||
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
for k, v in micro_batch.items():
|
pass
|
||||||
if isinstance(v, torch.Tensor) and v.grad is not None:
|
|
||||||
input_obj_grad[k] = v.grad
|
|
||||||
else:
|
else:
|
||||||
for k, v in input_obj.items():
|
for k, v in input_obj.items():
|
||||||
if isinstance(v, torch.Tensor) and v.grad is not None:
|
if isinstance(v, torch.Tensor) and v.grad is not None:
|
||||||
|
@ -550,10 +551,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
output_obj_, _ = tree_flatten(output_obj) # y
|
output_obj_, _ = tree_flatten(output_obj) # y
|
||||||
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
|
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
|
||||||
|
|
||||||
|
# filter item which is not torch.Tensor
|
||||||
|
output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
|
||||||
|
output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]
|
||||||
|
|
||||||
optimizer.backward_by_grad(
|
optimizer.backward_by_grad(
|
||||||
tensor=output_obj_,
|
tensor=output_obj_,
|
||||||
grad=output_obj_grad_,
|
grad=output_obj_grad_,
|
||||||
inputs=list(model_chunk[model_chunk_id].parameters()),
|
inputs=list(model_chunk.parameters()),
|
||||||
retain_graph=False,
|
retain_graph=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -634,7 +639,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
tree_map(release_tensor_data, output_obj)
|
tree_map(release_tensor_data, output_obj)
|
||||||
|
|
||||||
# add input and output object for backward b
|
# add input and output object for backward b
|
||||||
self.input_tensors[model_chunk_id].append((micro_batch, input_obj))
|
self.input_tensors[model_chunk_id].append(input_obj)
|
||||||
|
|
||||||
# for bwd b&w, we only need the graph(grad_fn) of output_obj
|
# for bwd b&w, we only need the graph(grad_fn) of output_obj
|
||||||
# Do not release_tensor_data loss, release_tensor_data other output_obj;
|
# Do not release_tensor_data loss, release_tensor_data other output_obj;
|
||||||
|
@ -692,7 +697,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
|
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
|
||||||
|
|
||||||
# get input and output object from buffer;
|
# get input and output object from buffer;
|
||||||
micro_batch, input_obj = self.input_tensors[model_chunk_id].pop(0)
|
input_obj = self.input_tensors[model_chunk_id].pop(0)
|
||||||
output_obj = self.output_tensors[model_chunk_id].pop(0)
|
output_obj = self.output_tensors[model_chunk_id].pop(0)
|
||||||
|
|
||||||
# save output_tensor_grad for dw
|
# save output_tensor_grad for dw
|
||||||
|
@ -708,7 +713,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
model_chunk=model_chunk,
|
model_chunk=model_chunk,
|
||||||
model_chunk_id=model_chunk_id,
|
model_chunk_id=model_chunk_id,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
micro_batch=micro_batch,
|
|
||||||
input_obj=input_obj,
|
input_obj=input_obj,
|
||||||
output_obj=output_obj,
|
output_obj=output_obj,
|
||||||
output_obj_grad=output_tensor_grad,
|
output_obj_grad=output_tensor_grad,
|
||||||
|
|
|
@ -26,6 +26,7 @@ class PipelineStageManager:
|
||||||
pg_mesh: ProcessGroupMesh,
|
pg_mesh: ProcessGroupMesh,
|
||||||
pipeline_axis: int,
|
pipeline_axis: int,
|
||||||
enable_interleave: bool = False,
|
enable_interleave: bool = False,
|
||||||
|
use_zbv: bool = False,
|
||||||
num_model_chunks: int = 1,
|
num_model_chunks: int = 1,
|
||||||
num_layers_per_stage: Optional[List[int]] = None,
|
num_layers_per_stage: Optional[List[int]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -49,6 +50,7 @@ class PipelineStageManager:
|
||||||
next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :]
|
next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :]
|
||||||
self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap")
|
self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap")
|
||||||
self.is_interleave = enable_interleave
|
self.is_interleave = enable_interleave
|
||||||
|
self.use_zbv = use_zbv
|
||||||
# for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
|
# for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
|
||||||
self.num_model_chunks: int = num_model_chunks
|
self.num_model_chunks: int = num_model_chunks
|
||||||
# for shardformer, hold stage indices of model
|
# for shardformer, hold stage indices of model
|
||||||
|
@ -85,6 +87,16 @@ class PipelineStageManager:
|
||||||
num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
|
num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
|
||||||
|
|
||||||
stage_indices = []
|
stage_indices = []
|
||||||
|
if self.use_zbv:
|
||||||
|
stage_indices.append([num_layers_per_stage_accumulated[stage], num_layers_per_stage_accumulated[stage + 1]])
|
||||||
|
stage_indices.append(
|
||||||
|
[
|
||||||
|
num_layers_per_stage_accumulated[2 * num_stages - stage - 1],
|
||||||
|
num_layers_per_stage_accumulated[2 * num_stages - stage],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return stage_indices
|
||||||
|
|
||||||
for model_chunk in range(num_model_chunks):
|
for model_chunk in range(num_model_chunks):
|
||||||
start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages]
|
start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages]
|
||||||
end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
|
end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
|
||||||
|
|
|
@ -15,6 +15,7 @@ class _PipelineStageManager(PipelineStageManager):
|
||||||
self.is_interleave = False
|
self.is_interleave = False
|
||||||
self.num_layers_per_stage = None
|
self.num_layers_per_stage = None
|
||||||
self.num_model_chunks = 1
|
self.num_model_chunks = 1
|
||||||
|
self.use_zbv = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_stages(self):
|
def num_stages(self):
|
||||||
|
|
|
@ -15,6 +15,7 @@ class _PipelineStageManager(PipelineStageManager):
|
||||||
self.is_interleave = False
|
self.is_interleave = False
|
||||||
self.num_layers_per_stage = None
|
self.num_layers_per_stage = None
|
||||||
self.num_model_chunks = 1
|
self.num_model_chunks = 1
|
||||||
|
self.use_zbv = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_stages(self):
|
def num_stages(self):
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from types import MethodType
|
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -22,36 +21,50 @@ from tests.kit.model_zoo import model_zoo
|
||||||
|
|
||||||
|
|
||||||
class MlpModel(nn.Module):
|
class MlpModel(nn.Module):
|
||||||
def __init__(self, in_dim, out_dim, num_layers):
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_dim,
|
||||||
|
out_dim,
|
||||||
|
num_layers,
|
||||||
|
stage_index=None,
|
||||||
|
stage_mgr: PipelineStageManager = None,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)])
|
self.layers = nn.Sequential(*[nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)])
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
|
||||||
):
|
|
||||||
for layer in self.layers:
|
|
||||||
hidden_states = layer(hidden_states)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
def pp_linear_fwd(
|
|
||||||
forward,
|
|
||||||
data: torch.Tensor = None,
|
data: torch.Tensor = None,
|
||||||
hidden_states: torch.Tensor = None,
|
hidden_states: torch.Tensor = None,
|
||||||
|
stage_index=None,
|
||||||
stage_mgr: PipelineStageManager = None,
|
stage_mgr: PipelineStageManager = None,
|
||||||
model_chunk_id: int = None,
|
model_chunk_id: int = None,
|
||||||
):
|
):
|
||||||
with stage_mgr.switch_model_chunk_id(model_chunk_id):
|
if stage_mgr is None:
|
||||||
|
hidden_states = data
|
||||||
|
for layer in self.layers:
|
||||||
|
hidden_states = layer(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
else:
|
||||||
|
# Keep only the layers held by this stage (selected by stage_index)
|
||||||
|
held_layers = self.layers[stage_index[0] : stage_index[1]]
|
||||||
|
|
||||||
# fwd end
|
# fwd end
|
||||||
if stage_mgr.is_first_stage() and model_chunk_id == 1:
|
if stage_mgr.is_first_stage() and stage_mgr.model_chunk_id == 1:
|
||||||
return forward(hidden_states)
|
return held_layers(hidden_states)
|
||||||
# fwd start
|
# fwd start
|
||||||
elif stage_mgr.is_first_stage() and model_chunk_id == 0:
|
elif stage_mgr.is_first_stage() and stage_mgr.model_chunk_id == 0:
|
||||||
return {"hidden_states": forward(data)}
|
return {"hidden_states": held_layers(data)}
|
||||||
# fwd middle
|
# fwd middle
|
||||||
else:
|
else:
|
||||||
return {"hidden_states": forward(hidden_states)}
|
return {"hidden_states": held_layers(hidden_states)}
|
||||||
|
|
||||||
|
|
||||||
|
def assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups):
|
||||||
|
for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_param_groups.items(), optim_pp_param_groups.items()):
|
||||||
|
if key_base == key_pp:
|
||||||
|
if key_base != "params":
|
||||||
|
assert val_base == val_pp
|
||||||
|
|
||||||
|
|
||||||
def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
|
def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
|
||||||
|
@ -554,7 +567,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
|
||||||
num_model_chunk = test_config["num_model_chunk"]
|
num_model_chunk = test_config["num_model_chunk"]
|
||||||
# stage_manager
|
# stage_manager
|
||||||
stage_manager = PipelineStageManager(
|
stage_manager = PipelineStageManager(
|
||||||
pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk
|
pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk, use_zbv=True
|
||||||
)
|
)
|
||||||
|
|
||||||
h, a, s = 4096, 32, 1024
|
h, a, s = 4096, 32, 1024
|
||||||
|
@ -600,67 +613,27 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
|
||||||
before_init_memory = torch.cuda.memory_allocated() / 1024**3
|
before_init_memory = torch.cuda.memory_allocated() / 1024**3
|
||||||
print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};")
|
print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};")
|
||||||
model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
|
model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
|
||||||
# data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
|
|
||||||
data_iter = {"data": torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)}
|
data_iter = {"data": torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)}
|
||||||
# input_base = [t.clone() for t in data_iter]
|
|
||||||
input_base = {k: v.clone() for k, v in data_iter.items()}
|
input_base = {k: v.clone() for k, v in data_iter.items()}
|
||||||
model_base = deepcopy(model)
|
model_base = deepcopy(model)
|
||||||
|
model_pp = deepcopy(model)
|
||||||
|
layers_per_stage = stage_manager.distribute_layers(len(model.layers))
|
||||||
|
stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
|
||||||
|
|
||||||
if rank == 0:
|
model_pp._forward = model_pp.forward
|
||||||
# layer 0 & 7 to chunk 0 on rank0
|
|
||||||
local_chunk = torch.nn.ModuleList().to(rank)
|
model_pp.forward = partial(model_pp._forward, stage_mgr=stage_manager)
|
||||||
for idx, sub_model in enumerate(model.layers):
|
|
||||||
if idx == 0 or idx == 7:
|
|
||||||
sub_model._forward = sub_model.forward
|
|
||||||
sub_model.forward = MethodType(
|
|
||||||
partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
|
|
||||||
sub_model._forward,
|
|
||||||
)
|
|
||||||
local_chunk.append(sub_model)
|
|
||||||
elif rank == 1:
|
|
||||||
# layer 1 & 6 to chunk 1 on rank1
|
|
||||||
local_chunk = torch.nn.ModuleList().to(rank)
|
|
||||||
for idx, sub_model in enumerate(model.layers):
|
|
||||||
if idx == 1 or idx == 6:
|
|
||||||
sub_model._forward = sub_model.forward
|
|
||||||
sub_model.forward = MethodType(
|
|
||||||
partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
|
|
||||||
sub_model._forward,
|
|
||||||
)
|
|
||||||
local_chunk.append(sub_model)
|
|
||||||
elif rank == 2:
|
|
||||||
# layer 2 & 5 to chunk 2 on rank2
|
|
||||||
local_chunk = torch.nn.ModuleList().to(rank)
|
|
||||||
for idx, sub_model in enumerate(model.layers):
|
|
||||||
if idx == 2 or idx == 5:
|
|
||||||
sub_model._forward = sub_model.forward
|
|
||||||
sub_model.forward = MethodType(
|
|
||||||
partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
|
|
||||||
sub_model._forward,
|
|
||||||
)
|
|
||||||
local_chunk.append(sub_model)
|
|
||||||
else:
|
|
||||||
# layer 3 & 4 to chunk 3 on rank3
|
|
||||||
local_chunk = torch.nn.ModuleList().to(rank)
|
|
||||||
for idx, sub_model in enumerate(model.layers):
|
|
||||||
if idx == 3 or idx == 4:
|
|
||||||
sub_model._forward = sub_model.forward
|
|
||||||
sub_model.forward = MethodType(
|
|
||||||
partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
|
|
||||||
sub_model._forward,
|
|
||||||
)
|
|
||||||
local_chunk.append(sub_model)
|
|
||||||
|
|
||||||
# init optimizer
|
# init optimizer
|
||||||
optimizer_base = torch.optim.SGD(model_base.parameters(), momentum=0.1, lr=1e-5)
|
optimizer_base = torch.optim.SGD(model_base.parameters(), momentum=0.1, lr=1e-5)
|
||||||
optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), momentum=0.1, lr=1e-5))
|
optimizer_pp = OptimizerWrapper(torch.optim.SGD(model_pp.parameters(), momentum=0.1, lr=1e-5))
|
||||||
|
|
||||||
after_init_memory = torch.cuda.memory_allocated() / 1024**3
|
after_init_memory = torch.cuda.memory_allocated() / 1024**3
|
||||||
print(f"After init Model & input: {after_init_memory :.5f} GB on device {stage_manager.get_rank()};")
|
print(f"After init Model & input: {after_init_memory :.5f} GB on device {stage_manager.get_rank()};")
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
result = scheduler.forward_backward_step(
|
result = scheduler.forward_backward_step(
|
||||||
model_chunk=local_chunk,
|
model_chunk=model_pp,
|
||||||
data_iter=iter([data_iter]),
|
data_iter=iter([data_iter]),
|
||||||
criterion=criterion,
|
criterion=criterion,
|
||||||
optimizer=optimizer_pp,
|
optimizer=optimizer_pp,
|
||||||
|
@ -694,7 +667,8 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
|
||||||
# Fwd bwd for base
|
# Fwd bwd for base
|
||||||
##########################
|
##########################
|
||||||
# fwd & bwd
|
# fwd & bwd
|
||||||
output_base = model_base(input_base["data"])
|
# output_base = model_base(input_base["data"])
|
||||||
|
output_base = model_base.forward(data=input_base["data"])
|
||||||
loss_base = criterion_base(output_base)
|
loss_base = criterion_base(output_base)
|
||||||
loss_base.backward()
|
loss_base.backward()
|
||||||
optimizer_base.step()
|
optimizer_base.step()
|
||||||
|
@ -707,63 +681,53 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
|
||||||
assert_close(result["loss"], loss_base)
|
assert_close(result["loss"], loss_base)
|
||||||
assert_close(result["outputs"]["hidden_states"], output_base)
|
assert_close(result["outputs"]["hidden_states"], output_base)
|
||||||
|
|
||||||
# print(f"pp result {result}; base result loss:{loss_base} output_base:{output_base} ")
|
# ##########################
|
||||||
##########################
|
# # assert weight & optim state
|
||||||
# assert weight
|
# ##########################
|
||||||
##########################
|
|
||||||
if rank == 0:
|
|
||||||
# layer 0
|
|
||||||
assert_close(local_chunk[0].weight, model_base.layers[0].weight)
|
|
||||||
assert_close(local_chunk[0].weight.grad, model_base.layers[0].weight.grad)
|
|
||||||
# layer 7
|
|
||||||
assert_close(local_chunk[1].weight, model_base.layers[7].weight)
|
|
||||||
assert_close(local_chunk[1].weight.grad, model_base.layers[7].weight.grad)
|
|
||||||
if rank == 1:
|
|
||||||
# layer 1
|
|
||||||
assert_close(local_chunk[0].weight, model_base.layers[1].weight)
|
|
||||||
assert_close(local_chunk[0].weight.grad, model_base.layers[1].weight.grad)
|
|
||||||
# layer 6
|
|
||||||
assert_close(local_chunk[1].weight, model_base.layers[6].weight)
|
|
||||||
assert_close(local_chunk[1].weight.grad, model_base.layers[6].weight.grad)
|
|
||||||
if rank == 2:
|
|
||||||
# layer 2
|
|
||||||
assert_close(local_chunk[0].weight, model_base.layers[2].weight)
|
|
||||||
assert_close(local_chunk[0].weight.grad, model_base.layers[2].weight.grad)
|
|
||||||
# layer 5
|
|
||||||
assert_close(local_chunk[1].weight, model_base.layers[5].weight)
|
|
||||||
assert_close(local_chunk[1].weight.grad, model_base.layers[5].weight.grad)
|
|
||||||
if rank == 3:
|
|
||||||
# layer 3
|
|
||||||
assert_close(local_chunk[0].weight, model_base.layers[3].weight)
|
|
||||||
assert_close(local_chunk[0].weight.grad, model_base.layers[3].weight.grad)
|
|
||||||
# layer 4
|
|
||||||
assert_close(local_chunk[1].weight, model_base.layers[4].weight)
|
|
||||||
assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)
|
|
||||||
|
|
||||||
##########################
|
|
||||||
# assert optim state
|
|
||||||
##########################
|
|
||||||
optim_base_state = optimizer_base.state_dict()["state"]
|
optim_base_state = optimizer_base.state_dict()["state"]
|
||||||
optim_pp_state = optimizer_pp.state_dict()["state"]
|
optim_pp_state = optimizer_pp.state_dict()["state"]
|
||||||
optim_base_param_groups = optimizer_base.state_dict()["param_groups"][0]
|
optim_base_param_groups = optimizer_base.state_dict()["param_groups"][0]
|
||||||
optim_pp_param_groups = optimizer_pp.state_dict()["param_groups"][0]
|
optim_pp_param_groups = optimizer_pp.state_dict()["param_groups"][0]
|
||||||
# if rank == 0:
|
|
||||||
# print(f"optim_base_state {optim_base_state}")
|
|
||||||
|
|
||||||
# assert param group
|
if rank == 0:
|
||||||
for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_param_groups.items(), optim_pp_param_groups.items()):
|
# layer 0
|
||||||
if key_base == key_pp:
|
assert_close(model_pp.layers[0].weight, model_base.layers[0].weight)
|
||||||
if key_base != "params":
|
assert_close(model_pp.layers[0].weight.grad, model_base.layers[0].weight.grad)
|
||||||
assert val_base == val_pp
|
assert_close(optim_pp_state[0]["momentum_buffer"], optim_base_state[0]["momentum_buffer"])
|
||||||
else:
|
# layer 7
|
||||||
# BUG:
|
assert_close(model_pp.layers[7].weight, model_base.layers[7].weight)
|
||||||
# param_base: [0, 1, 2, 3, 4, 5, 6, 7];
|
assert_close(model_pp.layers[7].weight.grad, model_base.layers[7].weight.grad)
|
||||||
# params pp: [0, 1];
|
assert_close(optim_pp_state[7]["momentum_buffer"], optim_base_state[7]["momentum_buffer"])
|
||||||
assert val_base[:2] == val_pp
|
if rank == 1:
|
||||||
|
# layer 1
|
||||||
|
assert_close(model_pp.layers[1].weight, model_base.layers[1].weight)
|
||||||
|
assert_close(model_pp.layers[1].weight.grad, model_base.layers[1].weight.grad)
|
||||||
|
assert_close(optim_pp_state[1]["momentum_buffer"], optim_base_state[1]["momentum_buffer"])
|
||||||
|
# layer 6
|
||||||
|
assert_close(model_pp.layers[6].weight, model_base.layers[6].weight)
|
||||||
|
assert_close(model_pp.layers[6].weight.grad, model_base.layers[6].weight.grad)
|
||||||
|
assert_close(optim_pp_state[6]["momentum_buffer"], optim_base_state[6]["momentum_buffer"])
|
||||||
|
if rank == 2:
|
||||||
|
# layer 2
|
||||||
|
assert_close(model_pp.layers[2].weight, model_base.layers[2].weight)
|
||||||
|
assert_close(model_pp.layers[2].weight.grad, model_base.layers[2].weight.grad)
|
||||||
|
assert_close(optim_pp_state[2]["momentum_buffer"], optim_base_state[2]["momentum_buffer"])
|
||||||
|
# layer 5
|
||||||
|
assert_close(model_pp.layers[5].weight, model_base.layers[5].weight)
|
||||||
|
assert_close(model_pp.layers[5].weight.grad, model_base.layers[5].weight.grad)
|
||||||
|
assert_close(optim_pp_state[5]["momentum_buffer"], optim_base_state[5]["momentum_buffer"])
|
||||||
|
if rank == 3:
|
||||||
|
# layer 3
|
||||||
|
assert_close(model_pp.layers[3].weight, model_base.layers[3].weight)
|
||||||
|
assert_close(model_pp.layers[3].weight.grad, model_base.layers[3].weight.grad)
|
||||||
|
assert_close(optim_pp_state[3]["momentum_buffer"], optim_base_state[3]["momentum_buffer"])
|
||||||
|
# layer 4
|
||||||
|
assert_close(model_pp.layers[4].weight, model_base.layers[4].weight)
|
||||||
|
assert_close(model_pp.layers[4].weight.grad, model_base.layers[4].weight.grad)
|
||||||
|
assert_close(optim_pp_state[4]["momentum_buffer"], optim_base_state[4]["momentum_buffer"])
|
||||||
|
|
||||||
# assert state
|
# assert optim param_groups
|
||||||
assert_close(optim_pp_state[0]["momentum_buffer"], optim_base_state[2 * rank]["momentum_buffer"])
|
assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups)
|
||||||
assert_close(optim_pp_state[1]["momentum_buffer"], optim_base_state[2 * rank + 1]["momentum_buffer"])
|
|
||||||
|
|
||||||
|
|
||||||
# TODO:4) support Hybrid base 3)
|
# TODO:4) support Hybrid base 3)
|
||||||
|
|
Loading…
Reference in New Issue