mirror of https://github.com/InternLM/InternLM

fix(*): move sequence_parallel to parallel config (#224)

* move sequence_parallel to parallel config
* set the sequence_parallel default value to False
* fix lint

Author: ytxiong
Branch: pull/225/head
Parent: 32664328e7
Commit: a017cab4b3
@@ -125,7 +125,6 @@ model = dict(
     layer_norm_epsilon=1e-5,
     use_flash_attn=True,
     num_chunks=1,  # if num_chunks > 1, interleaved pipeline scheduler is used.
-    sequence_parallel=False,
 )
 """
 zero1 parallel:
@@ -142,6 +141,7 @@ tensor parallel: tensor parallel size, usually the number of GPUs per node.
 parallel = dict(
     zero1=8,
     pipeline=dict(size=1, interleaved_overlap=True),
+    sequence_parallel=False,
 )
 
 cudnn_deterministic = False
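
For reference, a training config touched by this change now declares the flag under parallel rather than model; a minimal sketch of the affected section (field values taken from the hunk above, everything else elided):

# Sketch of the affected config section after this commit (values copied from the
# hunk above; surrounding fields elided).
model = dict(
    layer_norm_epsilon=1e-5,
    use_flash_attn=True,
    num_chunks=1,  # if num_chunks > 1, interleaved pipeline scheduler is used.
)

parallel = dict(
    zero1=8,
    pipeline=dict(size=1, interleaved_overlap=True),
    sequence_parallel=False,  # moved here from the model dict; off by default
)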
@@ -265,11 +265,13 @@ and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
     # process the model config
     if "use_flash_attn" not in gpc.config.model:
         gpc.config.model._add_item("use_flash_attn", True)
-    if "sequence_parallel" not in gpc.config.model:
-        gpc.config.model._add_item("sequence_parallel", False)
+
+    # process the parallel config
+    if "sequence_parallel" not in gpc.config.parallel:
+        gpc.config.parallel._add_item("sequence_parallel", False)
     else:
         assert not (
-            gpc.config.model.sequence_parallel is True and gpc.config.model.use_flash_attn is False
+            gpc.config.parallel.sequence_parallel is True and gpc.config.model.use_flash_attn is False
         ), "sequence parallel does not support use_flash_attn=False"
 
     # feishu webhook address for alerting
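
The launch-time processing above back-fills a default and rejects one unsupported combination; a hedged standalone illustration of that behaviour (plain dicts stand in for the real gpc.config objects, which this is not):

# Illustration only: plain dicts stand in for gpc.config.model / gpc.config.parallel.
model_cfg = {"use_flash_attn": True}
parallel_cfg = {"zero1": 8, "pipeline": {"size": 1}, "sequence_parallel": True}

if "sequence_parallel" not in parallel_cfg:
    parallel_cfg["sequence_parallel"] = False  # an absent flag silently defaults to off
else:
    # with use_flash_attn=False this assertion would fire and abort the launch
    assert not (
        parallel_cfg["sequence_parallel"] is True and model_cfg["use_flash_attn"] is False
    ), "sequence parallel does not support use_flash_attn=False"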
@@ -56,7 +56,7 @@ class Embedding1D(nn.Module):
 
         output = gather_forward_split_backward(output_parallel, ParallelMode.TENSOR, dim=-1)
 
-        if gpc.config.model.sequence_parallel:
+        if gpc.config.parallel.sequence_parallel:
             output = split_forward_gather_backward(output, ParallelMode.TENSOR, dim=1)
 
         return output
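
For intuition, this path gathers the hidden-split embedding output across tensor-parallel ranks (dim=-1) and, with sequence parallelism enabled, re-splits the activation along the sequence dimension (dim=1). A rough single-process shape sketch; the batch-first [batch, seqlen, hidden] layout and tp=2 are illustrative assumptions, and cat/chunk merely stand in for the communication helpers:

import torch

# Illustrative shapes only; cat/chunk stand in for the all-gather over dim=-1 and
# the sequence split over dim=1 performed by the real communication helpers.
batch, seqlen, hidden, tp = 2, 8, 16, 2
output_parallel = torch.randn(batch, seqlen, hidden // tp)   # per-rank embedding slice
output = torch.cat([output_parallel] * tp, dim=-1)           # -> [batch, seqlen, hidden]
sequence_shard = output.chunk(tp, dim=1)[0]                  # -> [batch, seqlen // tp, hidden]
print(output.shape, sequence_shard.shape)                    # torch.Size([2, 8, 16]) torch.Size([2, 4, 16])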
@@ -62,7 +62,7 @@ class ScaleColumnParallelLinear(nn.Linear):
             weight,
             self.bias,
             process_group=self.process_group,
-            sequence_parallel=gpc.config.model.sequence_parallel,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
         )
 
 
@@ -111,7 +111,7 @@ class RewardModelLinear(ScaleColumnParallelLinear):
             weight,
             self.bias,
             process_group=self.process_group,
-            sequence_parallel=gpc.config.model.sequence_parallel,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
         )
 
 
@@ -173,7 +173,7 @@ class FeedForward(nn.Module):
             hidden_features,
             process_group,
             bias,
-            sequence_parallel=gpc.config.model.sequence_parallel,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
             device=device,
             dtype=dtype,
         )
@@ -182,7 +182,7 @@ class FeedForward(nn.Module):
             hidden_features,
             process_group,
             bias,
-            sequence_parallel=gpc.config.model.sequence_parallel,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
             device=device,
             dtype=dtype,
         )
@@ -191,7 +191,7 @@ class FeedForward(nn.Module):
             out_features,
             process_group,
             bias=bias,
-            sequence_parallel=gpc.config.model.sequence_parallel,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
             device=device,
             dtype=dtype,
         )
@@ -121,7 +121,7 @@ class PackedFlashBaseLayer1D(nn.Module):
                 process_group=gpc.get_group(ParallelMode.TENSOR),
                 bias1=False,
                 bias2=False,
-                sequence_parallel=gpc.config.model.sequence_parallel,
+                sequence_parallel=gpc.config.parallel.sequence_parallel,
                 checkpoint_lvl=0,
                 heuristic="auto",
                 device=device,
@@ -294,7 +294,7 @@ class PackedFlashInternLm1D(nn.Module):
                     max_position_embeddings=-1,
                     process_group=gpc.get_group(ParallelMode.TENSOR),
                     padding_idx=None,
-                    sequence_parallel=gpc.config.model.sequence_parallel,
+                    sequence_parallel=gpc.config.parallel.sequence_parallel,
                     device=device,
                     dtype=dtype,
                 )
@@ -82,7 +82,7 @@ class MHA(nn.Module):
             3 * embed_dim,
             process_group,
             bias=True,
-            sequence_parallel=gpc.config.model.sequence_parallel,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
             **factory_kwargs,
         )  # according to https://spaces.ac.cn/archives/9577
 
@@ -95,7 +95,11 @@ class MHA(nn.Module):
 
         # output projection always have the bias (for now)
         self.out_proj = RowParallelLinearTorch(
-            embed_dim, embed_dim, process_group, sequence_parallel=gpc.config.model.sequence_parallel, **factory_kwargs
+            embed_dim,
+            embed_dim,
+            process_group,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
+            **factory_kwargs,
         )
         # need to assign tp attribute so that internlm know it is tensor parallel module
         if gpc.get_world_size(ParallelMode.TENSOR) > 1:
@@ -52,12 +52,12 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape
 
 @contextmanager
 def switch_sequence_parallel_mode():
-    prev_mode = gpc.config.model.sequence_parallel
+    prev_mode = gpc.config.parallel.sequence_parallel
     try:
-        gpc.config.model.sequence_parallel = False
+        gpc.config.parallel.sequence_parallel = False
         yield
     finally:
-        gpc.config.model.sequence_parallel = prev_mode
+        gpc.config.parallel.sequence_parallel = prev_mode
 
 
 def evaluate_on_val_dls(
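
The context manager above temporarily forces sequence parallelism off around evaluation and restores the previous setting afterwards; a minimal usage sketch (run_validation is a hypothetical stand-in for the real evaluation call such as evaluate_on_val_dls):

# Usage sketch, assuming switch_sequence_parallel_mode and gpc come from the
# modules shown in the hunks above; run_validation() is a hypothetical placeholder.
with switch_sequence_parallel_mode():
    # inside the block gpc.config.parallel.sequence_parallel is False
    run_validation()
# on exit the previous sequence_parallel value is restored, even if validation raised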