mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] fix gathering output when using tensor parallelism (#5431)
* fix
* padding vocab_size when using pipeline parallelism
* fix gather output
* fix resize embedding
* revert
parent f2e8b9ef9f
commit 5e16bf7980
@@ -199,7 +199,12 @@ def get_param_info(optim: Optimizer):
     if optim is None:
         return {}
-    param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}}
+    param_info = {
+        "param_groups": [],
+        "param2id": {},
+        "id2param": {},
+        "param2shape": {},
+    }
     start_index = 0
     for group in optim.param_groups:
         packed_group = {k: v for k, v in group.items() if k != "params"}
@@ -899,6 +904,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
         enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
         enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
+        parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True.
         num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
         microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
             Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
@@ -939,6 +945,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_jit_fused: bool = False,
         enable_sequence_parallelism: bool = False,
         enable_sequence_overlap: bool = False,
+        parallel_output: bool = True,
         num_microbatches: Optional[int] = None,
         microbatch_size: Optional[int] = None,
         initial_scale: float = 2**16,
@@ -1035,6 +1042,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             enable_jit_fused=self.enable_jit_fused,
             enable_sequence_parallelism=enable_sequence_parallelism,
             enable_sequence_overlap=enable_sequence_overlap,
+            parallel_output=parallel_output,
         )
         self.amp_config = dict(
             initial_scale=initial_scale,

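Taken together, the plugin hunks above expose a new `parallel_output` switch on `HybridParallelPlugin` and forward it into the `ShardConfig` used by Shardformer. A minimal usage sketch, assuming a 2-GPU tensor-parallel job launched with `colossalai run --nproc_per_node 2`; the boost call is left commented because `model`, `optimizer`, `criterion` and `dataloader` are placeholders, not part of this commit:

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})

# parallel_output=False asks Shardformer to gather the sharded lm_logits back to
# the full vocabulary dimension in the forward pass (see the modeling hunks below).
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    parallel_output=False,
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)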
@@ -25,6 +25,7 @@ from colossalai.shardformer.layer._operation import gather_forward_split_backward
 from colossalai.shardformer.shard import ShardConfig

 from ..layer import cross_entropy_1d
+from ..layer._operation import gather_forward_split_backward


 class GPT2PipelineForwards:
@@ -337,6 +338,9 @@ class GPT2PipelineForwards:
             else:
                 loss = loss_fct(shift_logits, shift_labels)

+        if not shard_config.parallel_output:
+            lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)
+
         if not return_dict:
             output = (lm_logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
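For context, `gather_forward_split_backward` is ColossalAI's differentiable all-gather: the forward pass concatenates the vocab shards held by each tensor-parallel rank, and the backward pass hands each rank only the gradient slice matching its local shard, which appears to be the motivation for the `_gather` to `gather_forward_split_backward` swap in the llama hunks further down. A self-contained sketch of the same idea against raw `torch.distributed`; the class and helper names here are illustrative, not the library's internals:

import torch
import torch.distributed as dist


class _GatherForwardSplitBackward(torch.autograd.Function):
    """Sketch: all-gather along `dim` in forward, split the gradient in backward."""

    @staticmethod
    def forward(ctx, input_, dim, process_group):
        ctx.dim = dim
        ctx.process_group = process_group
        world_size = dist.get_world_size(process_group)
        if world_size == 1:
            return input_
        # Collect the equally-sized shard from every rank and concatenate along `dim`.
        shards = [torch.empty_like(input_) for _ in range(world_size)]
        dist.all_gather(shards, input_.contiguous(), group=process_group)
        return torch.cat(shards, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.process_group)
        if world_size == 1:
            return grad_output, None, None
        # Each rank keeps only the gradient chunk corresponding to its local shard.
        rank = dist.get_rank(ctx.process_group)
        grad_chunk = torch.chunk(grad_output, world_size, dim=ctx.dim)[rank]
        return grad_chunk.contiguous(), None, None


def gather_forward_split_backward_sketch(input_, dim, process_group=None):
    return _GatherForwardSplitBackward.apply(input_, dim, process_group)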
@@ -793,11 +797,12 @@ def get_gpt2_flash_attention_forward():
             scale = scale * (1 / float(self.layer_idx + 1))

         # use coloattention
-        attention = ColoAttention(
-            embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale
-        )
+        if not hasattr(self, "attention"):
+            self.attention = ColoAttention(
+                embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale
+            )

-        attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)
+        attn_output = self.attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)

         attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
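The flash-attention change above is a caching fix rather than a functional one: the wrapper used to build a fresh `ColoAttention` on every forward call, and now constructs it once and reuses it via `hasattr`. A generic sketch of this lazy-initialization pattern; `ExpensiveHelper` and `Block` are illustrative names, not ColossalAI classes:

import torch
import torch.nn as nn


class ExpensiveHelper:
    """Stand-in for an object that is costly to construct (e.g. an attention kernel wrapper)."""

    def __init__(self, embed_dim: int, num_heads: int):
        self.embed_dim = embed_dim
        self.num_heads = num_heads

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return x  # placeholder computation


class Block(nn.Module):
    def __init__(self, embed_dim: int = 64, num_heads: int = 4):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Build the helper lazily on the first call, then reuse the same instance
        # on every later call instead of re-instantiating it per forward pass.
        if not hasattr(self, "helper"):
            self.helper = ExpensiveHelper(self.embed_dim, self.num_heads)
        return self.helper(x)

One trade-off to note: arguments captured at first construction (such as `scale`, which depends on `self.layer_idx`) are frozen into the cached object, which is fine here because they are constant per layer.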
@@ -1083,6 +1088,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
             else:
                 loss = loss_fct(shift_logits, shift_labels)

+        if not shard_config.parallel_output:
+            lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)
+
         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output

@@ -16,7 +16,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig

 from ..layer import cross_entropy_1d
-from ..layer._operation import _gather
+from ..layer._operation import gather_forward_split_backward

 try:
     from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
@@ -290,7 +290,7 @@ class LlamaPipelineForwards:
             loss = loss_fct(shift_logits, shift_labels)

         if not shard_config.parallel_output:
-            logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
+            logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group)

         if not return_dict:
             output = (logits,) + outputs[1:]
@@ -485,8 +485,9 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig):
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
             attn_mask_type = AttnMaskType.paddedcausal

-        attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
-        attn_output = attention(
+        if not hasattr(self, "attention"):
+            self.attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
+        attn_output = self.attention(
             query_states,
             key_states,
             value_states,
@@ -593,7 +594,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
             loss = loss_fct(shift_logits, shift_labels)

         if not shard_config.parallel_output:
-            logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
+            logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group)

         if not return_dict:
             output = (logits,) + outputs[1:]

@@ -242,4 +242,4 @@ class Policy(ABC):
             end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
             stage_indices.append([start_idx, end_idx])

-        return stage_indices[0] if num_model_chunks == 1 else stage_indices
+        return stage_indices[0] if num_model_chunks == 1 else stage_indices

@@ -34,8 +34,10 @@ class ShardConfig:
     enable_all_optimization: bool = False
     enable_sequence_parallelism: bool = False
     enable_sequence_overlap: bool = False
-    parallel_output = True
+    parallel_output: bool = True
     extra_kwargs: Dict[str, Any] = field(default_factory=dict)
     # TODO padding vocab
     # make_vocab_size_divisible_by: int = 128
     # pipeline_parallel_size: int
     # data_parallel_size: int
     # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']

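The `ShardConfig` hunk above looks cosmetic but is load-bearing: `ShardConfig` is a dataclass, so a bare `parallel_output = True` is only a class attribute, while the annotated `parallel_output: bool = True` becomes a real dataclass field that the plugin can set per instance (as in the plugin hunk earlier). A quick illustration with a throwaway config class:

from dataclasses import dataclass, fields


@dataclass
class DemoConfig:
    tensor_parallel_size: int = 1
    parallel_output = True       # no annotation: plain class attribute, NOT a dataclass field
    gather_output: bool = True   # annotated: a real field with a default


print([f.name for f in fields(DemoConfig)])  # ['tensor_parallel_size', 'gather_output']
print(DemoConfig(gather_output=False))       # DemoConfig(tensor_parallel_size=1, gather_output=False)
# DemoConfig(parallel_output=False)          # would raise TypeError: unexpected keyword argument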
@@ -260,7 +260,7 @@ def run_grad_acc_test(test_args):
         origin_model, origin_optimizer, dataloader=dataloader
     )
     for p1, p2 in zip(model.unwrap().parameters(), origin_model.unwrap().parameters()):
-        assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2)
+        assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2)


 def run_dist(rank, world_size, port, early_stop: bool = True):