mirror of https://github.com/InternLM/InternLM
				
				
				
			fix(ci): fix ci train error (#199)
							parent
							
								
									ef851d16c6
								
							
						
					
					
						commit
						db13bc46bc
					
				| 
						 | 
				
			
			@ -547,6 +547,8 @@ def send_backward_and_recv_next_backward_async(
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class AsynCommunicator:
 | 
			
		||||
    """AsynCommunicator for managing async communication."""
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        tensor_to_send: Union[torch.Tensor, List[torch.Tensor]],
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,6 +1,5 @@
 | 
			
		|||
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/communication
 | 
			
		||||
 | 
			
		||||
from functools import wraps
 | 
			
		||||
from typing import List, Tuple, Union
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -139,7 +139,7 @@ class PipelineScheduler(BaseScheduler):
 | 
			
		|||
            and gpc.is_initialized(ParallelMode.TENSOR)
 | 
			
		||||
            and gpc.get_world_size(ParallelMode.TENSOR) > 1
 | 
			
		||||
        )
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        if gpc.config.model.sequence_parallel:
 | 
			
		||||
            self.scatter_gather_tensors = False
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -139,7 +139,7 @@ def args_sanity_check():
 | 
			
		|||
                    gpc.config.ckpt._add_item("async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/")
 | 
			
		||||
 | 
			
		||||
    if "snapshot_ckpt_folder" not in gpc.config.ckpt:
 | 
			
		||||
        gpc.config.ckpt._add_item("snapshot_ckpt_folder", os.path.join(gpc.config.ckpt.save_ckpt_folder), "snapshot")
 | 
			
		||||
        gpc.config.ckpt._add_item("snapshot_ckpt_folder", os.path.join(gpc.config.ckpt.save_ckpt_folder, "snapshot"))
 | 
			
		||||
 | 
			
		||||
    if "oss_snapshot_freq" not in gpc.config.ckpt and gpc.config.ckpt.checkpoint_every != float("inf"):
 | 
			
		||||
        gpc.config.ckpt._add_item("oss_snapshot_freq", gpc.config.ckpt.checkpoint_every / 2)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -247,8 +247,12 @@ class ParameterStore(BaseStore):
 | 
			
		|||
 | 
			
		||||
    def get_reduced_param_for_compute_norm(self, group_id=0, last_bucket=False):
 | 
			
		||||
        if not last_bucket:
 | 
			
		||||
            if group_id not in self._former_bucket_reduced_param:
 | 
			
		||||
                return [], []
 | 
			
		||||
            return self._former_bucket_reduced_param[group_id], self._former_bucket_reduced_grad[group_id]
 | 
			
		||||
        else:
 | 
			
		||||
            if group_id not in self._last_bucket_reduced_param:
 | 
			
		||||
                return [], []
 | 
			
		||||
            return self._last_bucket_reduced_param[group_id], self._last_bucket_reduced_grad[group_id]
 | 
			
		||||
 | 
			
		||||
    def reset_reduced_data_for_compute_norm(self):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue