mirror of https://github.com/InternLM/InternLM
				
				
				
			fix the type_ids when micro_num=1 and use_flash_attn=False (#516)
							parent
							
								
									112c34ae09
								
							
						
					
					
						commit
						809ad9ebc8
					
				| 
						 | 
				
			
			@ -24,13 +24,16 @@ def get_dataset_type_id(dataset_type_ids_map, path):
 | 
			
		|||
    return match_idxes[0]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def unpack_data(input_ids, cu_seqlens):
 | 
			
		||||
    """
 | 
			
		||||
    input_ids: (n, packed_length)
 | 
			
		||||
    Return:
 | 
			
		||||
    output: (batch_size, max_length)
 | 
			
		||||
def unpack_data(input_ids, cu_seqlens, is_type_ids: bool = False):
 | 
			
		||||
    """
 | 
			
		||||
    input_ids: if input_ids is not type_ids, the shape is (1, packed_length)
 | 
			
		||||
               else the shape is (micro_num, packed_length)
 | 
			
		||||
    is_type_ids: whether the input_ids is type_ids
 | 
			
		||||
 | 
			
		||||
    Return:
 | 
			
		||||
    output: if input_ids is not type ids, the shape is (micro_bsz, max_length)
 | 
			
		||||
            else the shape is (micro_num, micro_bsz, max_length)
 | 
			
		||||
    """
 | 
			
		||||
    bsz = input_ids.shape[0]
 | 
			
		||||
 | 
			
		||||
    num_sequence = gpc.config.data["micro_bsz"]
 | 
			
		||||
| 
						 | 
				
			
			@ -45,7 +48,8 @@ def unpack_data(input_ids, cu_seqlens):
 | 
			
		|||
            output[j, 0:seq_length] = input_ids[0, cu_seqlens_slice[j] : cu_seqlens_slice[j + 1]]
 | 
			
		||||
        outputs[i] = output
 | 
			
		||||
 | 
			
		||||
    if bsz == 1:
 | 
			
		||||
    # if the input_ids is not type_ids, we need squeeze the first dimension if it is 1.
 | 
			
		||||
    if bsz == 1 and not is_type_ids:
 | 
			
		||||
        outputs = outputs.squeeze(0)
 | 
			
		||||
 | 
			
		||||
    return outputs
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -368,7 +368,7 @@ def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: Trai
 | 
			
		|||
    if batch[0].get("type_ids", None) is not None:
 | 
			
		||||
        # if use_flash_attn is False, we need to unpack type_ids
 | 
			
		||||
        if not gpc.config.model.use_flash_attn:
 | 
			
		||||
            batch[0]["type_ids"] = unpack_data(batch[0]["type_ids"], batch[0]["cu_seqlens"])
 | 
			
		||||
            batch[0]["type_ids"] = unpack_data(batch[0]["type_ids"], batch[0]["cu_seqlens"], is_type_ids=True)
 | 
			
		||||
 | 
			
		||||
    return batch, train_iter
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue