mirror of https://github.com/hpcaitech/ColossalAI
[feat] support no_tp Linear for sharderformer.llama
parent
8e40087633
commit
4fc92aa77d
|
@ -64,10 +64,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
|
||||
# P2PMeta cache
|
||||
self.enable_metadata_cache = enable_metadata_cache
|
||||
self.send_tensor_metadata = True
|
||||
self.send_grad_metadata = True
|
||||
self.tensor_metadata_recv = None
|
||||
self.grad_metadata_recv = None
|
||||
self.send_tensor_metadata = [True, True]
|
||||
self.send_grad_metadata = [True, True]
|
||||
# meta cache buffer
|
||||
self.tensor_metadata_recv = [None, None] # [chunk 0 meta, chunk 1 meta]
|
||||
self.grad_metadata_recv = [None, None]
|
||||
|
||||
# P2P communication
|
||||
self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p)
|
||||
|
@ -235,10 +236,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
else:
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
input_tensor, wait_handles = self.comm.recv_forward(
|
||||
prev_rank=prev_rank, metadata_recv=self.tensor_metadata_recv
|
||||
prev_rank=prev_rank, metadata_recv=self.tensor_metadata_recv[model_chunk_id]
|
||||
)
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
|
||||
self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
|
||||
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
|
||||
# return input_tensor, wait_handles
|
||||
return wait_handles
|
||||
|
@ -259,10 +260,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
else:
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
input_tensor, wait_handles = self.comm.recv_forward(
|
||||
next_rank, metadata_recv=self.tensor_metadata_recv
|
||||
next_rank, metadata_recv=self.tensor_metadata_recv[model_chunk_id]
|
||||
)
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
|
||||
self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
|
||||
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
|
||||
# return input_tensor, wait_handles
|
||||
return wait_handles
|
||||
|
@ -297,10 +298,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
else:
|
||||
next_rank = self.stage_manager.get_next_rank()
|
||||
output_tensor_grad, wait_handles = self.comm.recv_backward(
|
||||
next_rank, metadata_recv=self.grad_metadata_recv
|
||||
next_rank, metadata_recv=self.grad_metadata_recv[model_chunk_id]
|
||||
)
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
|
||||
self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
|
||||
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
|
||||
# return output_tensor_grad, wait_handles
|
||||
return wait_handles
|
||||
|
@ -322,10 +323,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
else:
|
||||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
output_tensor_grad, wait_handles = self.comm.recv_backward(
|
||||
next_rank=prev_rank, metadata_recv=self.grad_metadata_recv
|
||||
next_rank=prev_rank, metadata_recv=self.grad_metadata_recv[model_chunk_id]
|
||||
)
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
|
||||
self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
|
||||
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
|
||||
# return output_tensor_grad, wait_handles
|
||||
return wait_handles
|
||||
|
@ -359,9 +360,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
next_rank = self.stage_manager.get_next_rank()
|
||||
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
|
||||
send_handles = self.comm.send_forward(
|
||||
output_object=output_tensor, next_rank=next_rank, send_metadata=self.send_tensor_metadata
|
||||
output_object=output_tensor,
|
||||
next_rank=next_rank,
|
||||
send_metadata=self.send_tensor_metadata[model_chunk_id],
|
||||
)
|
||||
self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||
return send_handles
|
||||
|
||||
else:
|
||||
|
@ -380,9 +383,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
|
||||
send_handles = self.comm.send_forward(
|
||||
output_tensor, prev_rank, send_metadata=self.send_tensor_metadata
|
||||
output_tensor, prev_rank, send_metadata=self.send_tensor_metadata[model_chunk_id]
|
||||
)
|
||||
self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||
return send_handles
|
||||
|
||||
def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List:
|
||||
|
@ -415,9 +418,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
prev_rank = self.stage_manager.get_prev_rank()
|
||||
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
|
||||
send_handles = self.comm.send_backward(
|
||||
input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata
|
||||
input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata[model_chunk_id]
|
||||
)
|
||||
self.send_grad_metadata = not self.enable_metadata_cache
|
||||
self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||
return send_handles
|
||||
|
||||
# bwd chunk1 is left V;
|
||||
|
@ -437,9 +440,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
next_rank = self.stage_manager.get_next_rank()
|
||||
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
|
||||
send_handles = self.comm.send_backward(
|
||||
input_tensor_grad, next_rank, send_metadata=self.send_grad_metadata
|
||||
input_tensor_grad, next_rank, send_metadata=self.send_grad_metadata[model_chunk_id]
|
||||
)
|
||||
self.send_grad_metadata = not self.enable_metadata_cache
|
||||
self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
|
||||
return send_handles
|
||||
|
||||
def forward_step(
|
||||
|
@ -662,6 +665,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
accum_loss=accum_loss,
|
||||
outputs=outputs,
|
||||
)
|
||||
# print(f"stage {self.stage_manager.stage}; model_chunk_id {model_chunk_id}; output_obj {output_obj};")
|
||||
|
||||
# Step3:
|
||||
# 3-1:detach output; detach output for send fwd;
|
||||
|
@ -886,6 +890,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank)
|
||||
for it in range(len(schedule)):
|
||||
scheduled_node = schedule[it]
|
||||
# print(f"rank {torch.distributed.get_rank()}; stage {self.stage_manager.stage}; scheduled_node {scheduled_node};")
|
||||
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
|
||||
# communication
|
||||
communication_func = self.communication_map[scheduled_node.type]
|
||||
|
|
|
@ -191,7 +191,6 @@ class LlamaPipelineForwards:
|
|||
num_model_chunks=stage_manager.num_model_chunks,
|
||||
)
|
||||
assert num_ckpt_layers <= end_idx - start_idx
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
|
|
@ -9,6 +9,7 @@ from colossalai.shardformer.layer import (
|
|||
FusedRMSNorm,
|
||||
Linear1D_Col,
|
||||
Linear1D_Row,
|
||||
LinearWithGradAccum,
|
||||
PaddingEmbedding,
|
||||
PaddingLMHead,
|
||||
RMSNorm,
|
||||
|
@ -104,7 +105,7 @@ class LlamaPolicy(Policy):
|
|||
policy=policy,
|
||||
target_key=LlamaModel,
|
||||
)
|
||||
|
||||
# enable tp, replace layer to tp Linear1D_Col,Linear1D_Row,
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
assert (
|
||||
num_q_heads % tp_size == 0
|
||||
|
@ -191,6 +192,84 @@ class LlamaPolicy(Policy):
|
|||
],
|
||||
)
|
||||
|
||||
# not enable tp, replace layer to LinearWithGradAccum
|
||||
else:
|
||||
decoder_attribute_replacement = {
|
||||
"self_attn.hidden_size": self.model.config.hidden_size // tp_size,
|
||||
"self_attn.num_heads": num_q_heads,
|
||||
}
|
||||
if getattr(self.model.config, "num_key_value_heads", False):
|
||||
decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads
|
||||
|
||||
policy[LlamaDecoderLayer] = ModulePolicyDescription(
|
||||
attribute_replacement=decoder_attribute_replacement,
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.o_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.gate_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.up_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.down_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
if embedding_cls is not None:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
|
@ -416,6 +495,7 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
|
|||
policy = super().module_policy()
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
# enable tp, replace layer to tp Linear1D_Col,Linear1D_Row,
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# add a new item for sequence classification
|
||||
new_item = {
|
||||
|
@ -434,6 +514,25 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
|
|||
)
|
||||
}
|
||||
policy.update(new_item)
|
||||
# enable tp, replace layer to LinearWithGradAccum
|
||||
else:
|
||||
# add a new item for sequence classification
|
||||
new_item = {
|
||||
LlamaForSequenceClassification: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="score",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
}
|
||||
policy.update(new_item)
|
||||
|
||||
# to be confirmed
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
|
|
|
@ -163,8 +163,6 @@ def main():
|
|||
enable_async_reduce=not args.disable_async_reduce,
|
||||
use_fp8=args.use_fp8,
|
||||
fp8_communication=args.use_fp8_comm,
|
||||
use_fp8=args.use_fp8,
|
||||
fp8_communication=args.use_fp8_comm,
|
||||
)
|
||||
elif args.plugin == "gemini_auto":
|
||||
plugin = GeminiPlugin(
|
||||
|
@ -179,8 +177,6 @@ def main():
|
|||
enable_flash_attention=args.xformers,
|
||||
use_fp8=args.use_fp8,
|
||||
fp8_communication=args.use_fp8_comm,
|
||||
use_fp8=args.use_fp8,
|
||||
fp8_communication=args.use_fp8_comm,
|
||||
)
|
||||
elif args.plugin == "fsdp":
|
||||
if use_empty_init:
|
||||
|
@ -192,7 +188,6 @@ def main():
|
|||
),
|
||||
param_init_fn=empty_init(),
|
||||
fp8_communication=args.use_fp8_comm,
|
||||
fp8_communication=args.use_fp8_comm,
|
||||
)
|
||||
else:
|
||||
plugin = TorchFSDPPlugin(
|
||||
|
@ -214,7 +209,6 @@ def main():
|
|||
cpu_offload=CPUOffload(offload_params=True),
|
||||
param_init_fn=empty_init(),
|
||||
fp8_communication=args.use_fp8_comm,
|
||||
fp8_communication=args.use_fp8_comm,
|
||||
)
|
||||
else:
|
||||
plugin = TorchFSDPPlugin(
|
||||
|
@ -225,7 +219,6 @@ def main():
|
|||
),
|
||||
cpu_offload=CPUOffload(offload_params=True),
|
||||
fp8_communication=args.use_fp8_comm,
|
||||
fp8_communication=args.use_fp8_comm,
|
||||
)
|
||||
elif args.plugin == "3d":
|
||||
if args.pp_style == "zbv":
|
||||
|
|
|
@ -758,11 +758,13 @@ def run_with_hybridplugin(test_config):
|
|||
@parameterize(
|
||||
"config",
|
||||
[
|
||||
# (0, 1, 4, 1, 1),
|
||||
# # Pass
|
||||
(1, 2, 1, 1, 2),
|
||||
# TODO: adapt mixtral with no TP Linear
|
||||
# (1, 2, 2, 1, 1),
|
||||
(1, 1, 2, 2, 1),
|
||||
# (0, 1, 4, 1, 1),
|
||||
# (1, 1, 2, 2, 1),
|
||||
# (1, 2, 1, 2, 1),
|
||||
# (1, 2, 1, 1, 2),
|
||||
],
|
||||
)
|
||||
def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
|
||||
|
@ -910,7 +912,6 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
|
|||
p.grad /= dp_size
|
||||
torch_optimizer.step()
|
||||
torch_optimizer.zero_grad()
|
||||
|
||||
assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
|
||||
print(f"rank {dist.get_rank()} config {test_config} test passed")
|
||||
clear_layout_converter()
|
||||
|
@ -921,11 +922,12 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
|
|||
@parameterize(
|
||||
"config",
|
||||
[
|
||||
(1, 2, 2, 1), # Pass
|
||||
# TODO: only support pp + tp accleration; Will support fully pp and None tp Hybrid in furture;
|
||||
# (0, 4, 1, 1),
|
||||
# (1, 2, 1, 2),
|
||||
# (1, 1, 2, 2),
|
||||
# # Pass
|
||||
(1, 2, 2, 1),
|
||||
(1, 2, 1, 2),
|
||||
(1, 1, 2, 2),
|
||||
# TODO: acc err in pp4
|
||||
(1, 4, 1, 1),
|
||||
],
|
||||
)
|
||||
def run_with_booster_hybridplugin(config: Tuple[int, ...]):
|
||||
|
|
Loading…
Reference in New Issue