diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 45fe03003..fa3c3646a 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -999,7 +999,9 @@ class HybridParallelPlugin(PipelinePluginBase):
         ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"

         if enable_sequence_parallelism:
-            self.sequence_parallelism_mode = sequence_parallelism_mode if sequence_parallelism_mode is not None else "1"
+            self.sequence_parallelism_mode = (
+                sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
+            )
             assert (
                 self.sequence_parallelism_mode in SUPPORT_SP_MODE
             ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}"
@@ -1014,19 +1016,13 @@ class HybridParallelPlugin(PipelinePluginBase):
                 self.sp_size = 1
                 self.dp_size = dist.get_world_size() // (tp_size * pp_size)
             elif self.sequence_parallelism_mode in ["all_to_all"]:
-                assert (
-                    tp_size == 1
-                ), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with tensor parallelism"
-                assert (
-                    pp_size == 1
-                ), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with pipeline parallelism"
-                self.sp_size = dist.get_world_size() if sp_size is None else sp_size
-                self.dp_size = dist.get_world_size() // (self.sp_size * pp_size)
+                self.sp_size = 1 if sp_size is None else sp_size
+                self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size)
         else:
             self.dp_size = dist.get_world_size() // (tp_size * pp_size)
             assert (
                 sp_size == 1 or sp_size is None
-            ), f"sp_size can only be set to a >1 number when enable_sequence_parallelism is True"
+            ), f"You should not set sp_size when sequence parallelism is not enabled."
             self.sp_size = 1

         self.tp_size = tp_size
@@ -1040,11 +1036,22 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.enable_jit_fused = enable_jit_fused
         self.enable_sequence_parallelism = enable_sequence_parallelism
         if dp_outside:
-            self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
+            (
+                self.dp_axis,
+                self.pp_axis,
+                self.tp_axis,
+                self.sp_axis,
+            ) = (
+                0,
+                1,
+                2,
+                3,
+            )
             self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
         else:
             self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
             self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
+
         self.stage_manager = None
         self.schedule = None
         self.custom_policy = custom_policy
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index f47be48ee..bf5ce45a8 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -18,6 +18,7 @@ from transformers.models.llama.modeling_llama import (
     LlamaForSequenceClassification,
     LlamaModel,
     StaticCache,
+    apply_rotary_pos_emb,
     repeat_kv,
 )
 from transformers.utils import logging
@@ -459,37 +460,53 @@ class LlamaPipelineForwards:
         return {"hidden_states": hidden_states}


-def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
-    from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
-
-    try:
-        from transformers.models.llama.modeling_llama import repeat_kv
-    except:
-        warnings.warn("using llamav1, llamav1 hasn't repeat_kv function")
-
+def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
     def forward(
-        self: LlamaAttention,
+        self,
         hidden_states: torch.Tensor,
-        attention_mask: Optional[dict] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+        if sp_mode is not None:
+            assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
+            assert (sp_size is not None) and (
+                sp_group is not None
+            ), "Must specify sp_size and sp_group for sequence parallel"
         if "padding_mask" in kwargs:
             warnings.warn(
                 "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
             )
-        bsz, q_len, _ = hidden_states.size()
+        bsz, q_len, _ = hidden_states.size()
+        # sp: modify sp_len when sequence parallel mode is ring
         if sp_mode in ["split_gather", "ring"]:
             q_len *= sp_size

-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        if self.config.pretraining_tp > 1:
+            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
+            query_slices = self.q_proj.weight.split(
+                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
+            )
+            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
+            query_states = torch.cat(query_states, dim=-1)
+
+            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
+            key_states = torch.cat(key_states, dim=-1)
+
+            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
+            value_states = torch.cat(value_states, dim=-1)
+        else:
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)

         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
@@ -520,30 +537,66 @@ def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

+        # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

-        assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
-        attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        if shard_config.enable_flash_attention:
+            assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
+            attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
+        else:
+            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                    f" {attn_weights.size()}"
+                )
+            if attention_mask is not None:
+                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                    raise ValueError(
+                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                    )
+                attn_weights = attn_weights + attention_mask
+
+            # upcast attention to fp32
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+            attn_output = torch.matmul(attn_weights, value_states)
+
+            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+                raise ValueError(
+                    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                    f" {attn_output.size()}"
+                )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
+            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
             attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
-        attn_output = self.o_proj(attn_output)
+        else:
+            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

-        return attn_output, None, past_key_value
+        if self.config.pretraining_tp > 1:
+            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
+            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
+        else:
+            attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+        return attn_output, attn_weights, past_key_value

     return forward


-def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
+def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
     logger = logging.get_logger(__name__)
-    assert shard_config.enable_flash_attention, "Flash Attention is not enabled."

     def forward(
-        self: LlamaModel,
+        self,
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
@@ -560,7 +613,6 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
-
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         # retrieve input_ids and inputs_embeds
@@ -569,16 +621,18 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
                 "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
             )

-        if self.gradient_checkpointing and self.training and use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
-            )
-            use_cache = False
+        if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False

         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

         past_seen_tokens = 0
+        seq_len = inputs_embeds.shape[1]
         if use_cache:  # kept for BC (cache positions)
             if not isinstance(past_key_values, StaticCache):
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
@@ -586,32 +640,29 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
         if cache_position is None:
             if isinstance(past_key_values, StaticCache):
                 raise ValueError("cache_position is a required argument when using StaticCache.")
-            cache_position = torch.arange(
-                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
-            )
+            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device)
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        # embed positions
-        hidden_states = inputs_embeds
-        # in this case, attention_mask is a dict rather than a tensor
-        mask_shape = (hidden_states.shape[0], 1, past_seen_tokens, past_seen_tokens)
-        attention_mask = ColoAttention.prepare_attn_kwargs(
-            mask_shape,
-            hidden_states.dtype,
-            hidden_states.device,
-            q_padding_mask=attention_mask,
-            is_causal=True,
-        )
+        if shard_config.enable_flash_attention:
+            mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
+            attention_mask = ColoAttention.prepare_attn_kwargs(
+                mask_shape,
+                inputs_embeds.dtype,
+                inputs_embeds.device,
+                q_padding_mask=attention_mask,
+                is_causal=True,
+            )
+        else:
+            attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False

+        if sp_mode in ["ring", "split_gather"]:
+            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
+        elif sp_mode == "all_to_all":
+            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
+        hidden_states = inputs_embeds

         # decoder layers
         all_hidden_states = () if output_hidden_states else None
@@ -621,7 +672,6 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
         for decoder_layer in self.layers:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
-
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
@@ -633,6 +683,7 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
                     use_cache,
                     cache_position,
                 )
+
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
@@ -654,6 +705,11 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):

         hidden_states = self.norm(hidden_states)

+        if sp_mode == "ring" or sp_mode == "split_gather":
+            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+        elif sp_mode == "all_to_all":
+            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
+
         # add hidden states from the last decoder layer
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
@@ -665,6 +721,7 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
         )
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=next_cache,
@@ -778,240 +835,3 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
         )

     return forward
-
-
-def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group):
-    from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        cache_position: Optional[torch.LongTensor] = None,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        bsz, q_len, _ = hidden_states.size()
-        # sp: modify sp_len when sequence parallel mode is ring
-        if sp_mode in ["split_gather", "ring"]:
-            q_len *= sp_size
-
-        if self.config.pretraining_tp > 1:
-            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
-            query_slices = self.q_proj.weight.split(
-                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
-            )
-            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
-            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
-            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
-            query_states = torch.cat(query_states, dim=-1)
-
-            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
-            key_states = torch.cat(key_states, dim=-1)
-
-            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
-            value_states = torch.cat(value_states, dim=-1)
-
-        else:
-            query_states = self.q_proj(hidden_states)
-            key_states = self.k_proj(hidden_states)
-            value_states = self.v_proj(hidden_states)
-
-        # sp: all-to-all comminucation when introducing sequence parallel
-        if sp_mode == "all_to_all":
-            query_states = all_to_all_comm(query_states, sp_group)
-            key_states = all_to_all_comm(key_states, sp_group)
-            value_states = all_to_all_comm(value_states, sp_group)
-            bsz, q_len, _ = query_states.size()
-
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        past_key_value = getattr(self, "past_key_value", past_key_value)
-        cos, sin = self.rotary_emb(value_states, position_ids)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
-        if past_key_value is not None:
-            # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attention_mask is not None:  # no matter the length, we just slice it
-            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-            attn_weights = attn_weights + causal_mask
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
-
-        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-            raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                f" {attn_output.size()}"
-            )
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        # sp: all-to-all comminucation when introducing sequence parallel
-        if sp_mode == "all_to_all":
-            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
-            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
-        else:
-            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
-        if self.config.pretraining_tp > 1:
-            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
-            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
-            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
-        else:
-            attn_output = self.o_proj(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-        return attn_output, attn_weights, past_key_value
-
-    return forward
-
-
-def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
-    logger = logging.get_logger(__name__)
-
-    def forward(
-        self: LlamaModel,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        # retrieve input_ids and inputs_embeds
-        if input_ids is not None and inputs_embeds is not None:
-            raise ValueError(
-                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time, and must specify either one"
-            )
-
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
-
-        if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
-
-        past_seen_tokens = 0
-        if use_cache:  # kept for BC (cache positions)
-            if not isinstance(past_key_values, StaticCache):
-                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-                past_seen_tokens = past_key_values.get_seq_length()
-        if cache_position is None:
-            if isinstance(past_key_values, StaticCache):
-                raise ValueError("cache_position is a required argument when using StaticCache.")
-            cache_position = torch.arange(
-                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
-            )
-        if position_ids is None:
-            position_ids = cache_position.unsqueeze(0)
-
-        attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
-
-        if sp_mode in ["ring", "split_gather"]:
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
-        elif sp_mode == "all_to_all":
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
-
-        hidden_states = inputs_embeds
-
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-        next_decoder_cache = () if use_cache else None
-
-        for idx, decoder_layer in enumerate(self.layers):
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, past_key_value=past_key_values, output_attentions=output_attentions)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
-                    hidden_states,
-                    attention_mask,
-                    position_ids,
-                )
-
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids,
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                )
-
-            hidden_states = layer_outputs[0]
-
-            if use_cache:
-                next_decoder_cache = (
-                    next_decoder_cache.to_legacy_cache()
-                    if isinstance(next_decoder_cache, Cache)
-                    else next_decoder_cache
-                )
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-        hidden_states = self.norm(hidden_states)
-
-        if sp_mode == "ring" or sp_mode == "split_gather":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
-        elif sp_mode == "all_to_all":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        next_cache = next_decoder_cache if use_cache else None
-        if not return_dict:
-            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
-
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-        )
-
-    return forward
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index a9c982231..5852713c2 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -20,9 +20,7 @@ from colossalai.shardformer.layer import (
 from ..modeling.llama import (
     LlamaPipelineForwards,
     get_llama_flash_attention_forward,
-    get_llama_model_forward_for_flash_attn,
-    get_llama_seq_parallel_attention_forward,
-    get_llama_seq_parallel_model_forward,
+    get_llama_flash_attention_model_forward,
     get_lm_forward_with_dist_cross_entropy,
 )
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -82,33 +80,7 @@ class LlamaPolicy(Policy):
         )
         sp_partial_derived = sp_mode in ["split_gather", "ring"]

-        use_flash_attention = self.shard_config.enable_flash_attention
-        # Currently sp cannot to be used with flashattention
-        if sp_mode in ["split_gather", "ring", "all_to_all"]:
-            if use_flash_attention:
-                warnings.warn(
-                    f"Sequence parallelism mode {sp_mode} need to be used with FlashAttention, will disable FlashAttention automatically."
-                )
-                use_flash_attention = False
-
-        if sp_mode in ["split_gather", "ring"]:
-            self.append_or_create_method_replacement(
-                description={
-                    "forward": get_llama_seq_parallel_model_forward(
-                        sp_mode=sp_mode, sp_size=sp_size, sp_group=sp_group
-                    ),
-                },
-                policy=policy,
-                target_key=LlamaModel,
-            )
-            self.append_or_create_method_replacement(
-                description={
-                    "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
-                },
-                policy=policy,
-                target_key=attn_cls,
-            )
-        elif sp_mode == "all_to_all":
+        if sp_mode == "all_to_all":
             decoder_attribute_replacement = {
                 "num_heads": self.model.config.num_attention_heads // sp_size,
             }
@@ -118,24 +90,27 @@ class LlamaPolicy(Policy):
             policy[attn_cls] = ModulePolicyDescription(
                 attribute_replacement=decoder_attribute_replacement,
             )
+        if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
+                    "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
                 },
                 policy=policy,
                 target_key=attn_cls,
             )
-            self.append_or_create_method_replacement(
-                description={
-                    "forward": get_llama_seq_parallel_model_forward(
-                        sp_mode=sp_mode,
-                        sp_size=sp_size,
-                        sp_group=sp_group,
-                    ),
-                },
-                policy=policy,
-                target_key=LlamaModel,
-            )
+            if self.pipeline_stage_manager is None:
+                self.append_or_create_method_replacement(
+                    description={
+                        "forward": get_llama_flash_attention_model_forward(
+                            self.shard_config,
+                            sp_mode=sp_mode,
+                            sp_size=sp_size,
+                            sp_group=sp_group,
+                        ),
+                    },
+                    policy=policy,
+                    target_key=LlamaModel,
+                )

         if self.shard_config.enable_tensor_parallelism:
             assert (
@@ -235,25 +210,6 @@ class LlamaPolicy(Policy):
                 target_key=LlamaModel,
             )

-        # use flash attention
-        if use_flash_attention:
-            self.append_or_create_method_replacement(
-                description={
-                    "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size),
-                },
-                policy=policy,
-                target_key=attn_cls,
-            )
-            if self.pipeline_stage_manager is None:
-                # replace llama model forward method
-                self.append_or_create_method_replacement(
-                    description={
-                        "forward": get_llama_model_forward_for_flash_attn(self.shard_config),
-                    },
-                    policy=policy,
-                    target_key=LlamaModel,
-                )
-
         return policy

     def postprocess(self):
diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py
index 8d4dae314..f6c975305 100644
--- a/examples/language/llama/benchmark.py
+++ b/examples/language/llama/benchmark.py
@@ -72,6 +72,7 @@ def main():
     parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini")
     parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini")
     parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
+    parser.add_argument("--sp", type=int, default=1, help="Sequence parallel size")
     parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini")
     parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
     parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
@@ -174,6 +175,8 @@ def main():
             tp_size=args.tp,
             pp_size=args.pp,
             zero_stage=args.zero,
+            sp_size=args.sp,
+            enable_sequence_parallelism=args.sp > 1,
             enable_fused_normalization=torch.cuda.is_available(),
             enable_flash_attention=args.xformers,
             microbatch_size=args.mbs,
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index 1628bf2f3..3a8a1357d 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -120,9 +120,20 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             atol, rtol = 1e-4, 1e-3
         else:
             atol, rtol = 5e-3, 5e-3
-        check_weight(
-            llama_model, shard_llama_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
-        )
+        try:
+            check_weight(
+                llama_model,
+                shard_llama_model,
+                col_layer_for_check,
+                tp_group,
+                atol=atol,
+                rtol=rtol,
+                dim=1,
+                verbose=False,
+            )
+        except Exception as e:
+            print(f"Failed config: {test_config}")
+            raise e

     # check grads
     check_all_grad_tensors(grads_to_check)
@@ -133,9 +144,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @parameterize(
     "test_config",
     [
-        {
+        {  # Test ring + Flash attention
             "tp_size": 2,
             "pp_size": 1,
+            "sp_size": 2,
             "num_microbatches": 1,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "ring",
@@ -145,14 +157,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "precision": "fp16",
             "initial_scale": 1,
         },
-        {
-            "tp_size": 4,
-            "pp_size": 1,
-            "num_microbatches": 1,
+        {  # Ulysses + Flash attention
+            "tp_size": 1,
+            "pp_size": 2,
+            "sp_size": 2,
+            "num_microbatches": 2,
             "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "split_gather",
-            "enable_flash_attention": False,
+            "sequence_parallelism_mode": "all_to_all",
+            "enable_flash_attention": True,
             "use_lazy_init": True,
+            "zero_stage": 1,
             "precision": "fp16",
             "initial_scale": 1,
         },
@@ -164,7 +178,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "all_to_all",
             "use_lazy_init": True,
-            "zero_stage": 2,
+            "zero_stage": 1,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
             "precision": "fp16",
             "initial_scale": 1,
         },
@@ -213,7 +238,11 @@ def run_llama_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")

     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        try:
+            check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        except Exception as e:
+            print(f"Failed config: {test_config}")
+            raise e

     clear_layout_converter()
     Randomizer.reset_index()
@@ -263,7 +292,11 @@ def run_llama_3d_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")

     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        try:
+            check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        except Exception as e:
+            print(f"Failed config: {test_config}")
+            raise e

     clear_layout_converter()
     Randomizer.reset_index()