mirror of https://github.com/InternLM/InternLM
				
				
				
			fix(ci): fix ci train error (#199)
							parent
							
								
									ef851d16c6
								
							
						
					
					
						commit
						db13bc46bc
					
				|  | @ -547,6 +547,8 @@ def send_backward_and_recv_next_backward_async( | |||
| 
 | ||||
| 
 | ||||
| class AsynCommunicator: | ||||
|     """AsynCommunicator for managing async communication.""" | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
|         tensor_to_send: Union[torch.Tensor, List[torch.Tensor]], | ||||
|  |  | |||
|  | @ -1,6 +1,5 @@ | |||
| # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/communication | ||||
| 
 | ||||
| from functools import wraps | ||||
| from typing import List, Tuple, Union | ||||
| 
 | ||||
| import torch | ||||
|  |  | |||
|  | @ -139,7 +139,7 @@ class PipelineScheduler(BaseScheduler): | |||
|             and gpc.is_initialized(ParallelMode.TENSOR) | ||||
|             and gpc.get_world_size(ParallelMode.TENSOR) > 1 | ||||
|         ) | ||||
|          | ||||
| 
 | ||||
|         if gpc.config.model.sequence_parallel: | ||||
|             self.scatter_gather_tensors = False | ||||
| 
 | ||||
|  |  | |||
|  | @ -139,7 +139,7 @@ def args_sanity_check(): | |||
|                     gpc.config.ckpt._add_item("async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/") | ||||
| 
 | ||||
|     if "snapshot_ckpt_folder" not in gpc.config.ckpt: | ||||
|         gpc.config.ckpt._add_item("snapshot_ckpt_folder", os.path.join(gpc.config.ckpt.save_ckpt_folder), "snapshot") | ||||
|         gpc.config.ckpt._add_item("snapshot_ckpt_folder", os.path.join(gpc.config.ckpt.save_ckpt_folder, "snapshot")) | ||||
| 
 | ||||
|     if "oss_snapshot_freq" not in gpc.config.ckpt and gpc.config.ckpt.checkpoint_every != float("inf"): | ||||
|         gpc.config.ckpt._add_item("oss_snapshot_freq", gpc.config.ckpt.checkpoint_every / 2) | ||||
|  |  | |||
|  | @ -247,8 +247,12 @@ class ParameterStore(BaseStore): | |||
| 
 | ||||
|     def get_reduced_param_for_compute_norm(self, group_id=0, last_bucket=False): | ||||
|         if not last_bucket: | ||||
|             if group_id not in self._former_bucket_reduced_param: | ||||
|                 return [], [] | ||||
|             return self._former_bucket_reduced_param[group_id], self._former_bucket_reduced_grad[group_id] | ||||
|         else: | ||||
|             if group_id not in self._last_bucket_reduced_param: | ||||
|                 return [], [] | ||||
|             return self._last_bucket_reduced_param[group_id], self._last_bucket_reduced_grad[group_id] | ||||
| 
 | ||||
|     def reset_reduced_data_for_compute_norm(self): | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 huangting4201
						huangting4201