mirror of https://github.com/InternLM/InternLM
				
				
				
			feat(model/overlap_handler.py): optimize reduce scatter mem pool
							parent
							
								
									b20f47a1fe
								
							
						
					
					
						commit
						e7f9f1d208
					
				| 
						 | 
				
			
			@ -125,37 +125,38 @@ class FSTPOverlapHandler:
 | 
			
		|||
 | 
			
		||||
        # if key not in dict
 | 
			
		||||
        if key not in self.reduce_scatter_memory_pool:
 | 
			
		||||
            self.reduce_scatter_memory_pool[key] = {"data": [], "used": []}
 | 
			
		||||
            self.reduce_scatter_memory_pool[key] = []
 | 
			
		||||
 | 
			
		||||
        # if the data is empty
 | 
			
		||||
        if len(self.reduce_scatter_memory_pool[key]["data"]) == 0:
 | 
			
		||||
            self.reduce_scatter_memory_pool[key]["data"].append(
 | 
			
		||||
        if len(self.reduce_scatter_memory_pool[key]) == 0:
 | 
			
		||||
            self.reduce_scatter_memory_pool[key].append(
 | 
			
		||||
                torch.zeros(
 | 
			
		||||
                    key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()
 | 
			
		||||
                ).contiguous()
 | 
			
		||||
            )
 | 
			
		||||
            self.reduce_scatter_memory_pool[key]["used"].append(True)
 | 
			
		||||
            return_idx = 0
 | 
			
		||||
            return return_idx
 | 
			
		||||
            setattr(self.reduce_scatter_memory_pool[key][return_idx], "idle", False)
 | 
			
		||||
            setattr(self.reduce_scatter_memory_pool[key][return_idx], "index", return_idx)
 | 
			
		||||
            return self.reduce_scatter_memory_pool[key][return_idx]
 | 
			
		||||
        else:  # if not empty
 | 
			
		||||
            for index, used in enumerate(self.reduce_scatter_memory_pool[key]["used"]):
 | 
			
		||||
                if used is False:
 | 
			
		||||
                    self.reduce_scatter_memory_pool[key]["used"][index] = True
 | 
			
		||||
            for index, mem_item in enumerate(self.reduce_scatter_memory_pool[key]):
 | 
			
		||||
                if mem_item.idle is True:
 | 
			
		||||
                    self.reduce_scatter_memory_pool[key][index].idle = False
 | 
			
		||||
                    return_idx = index
 | 
			
		||||
                    return return_idx
 | 
			
		||||
                    return self.reduce_scatter_memory_pool[key][return_idx]
 | 
			
		||||
            # if the memory pool is all used
 | 
			
		||||
            length = len(self.reduce_scatter_memory_pool[key]["data"])
 | 
			
		||||
            self.reduce_scatter_memory_pool[key]["data"].append(
 | 
			
		||||
            cur_len = len(self.reduce_scatter_memory_pool[key])
 | 
			
		||||
            self.reduce_scatter_memory_pool[key].append(
 | 
			
		||||
                torch.zeros(
 | 
			
		||||
                    key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()
 | 
			
		||||
                ).contiguous()
 | 
			
		||||
            )
 | 
			
		||||
            self.reduce_scatter_memory_pool[key]["used"].append(True)
 | 
			
		||||
            return_idx = length
 | 
			
		||||
            return return_idx
 | 
			
		||||
            setattr(self.reduce_scatter_memory_pool[key][cur_len], "idle", False)
 | 
			
		||||
            return_idx = cur_len
 | 
			
		||||
            setattr(self.reduce_scatter_memory_pool[key][return_idx], "index", return_idx)
 | 
			
		||||
            return self.reduce_scatter_memory_pool[key][return_idx]
 | 
			
		||||
 | 
			
		||||
    def release_reduce_scatter_memory(self, size, index):
 | 
			
		||||
        self.reduce_scatter_memory_pool[size]["used"][index] = False
 | 
			
		||||
    def release_reduce_scatter_memory(self, key, index):
 | 
			
		||||
        self.reduce_scatter_memory_pool[key][index].idle = True
 | 
			
		||||
 | 
			
		||||
    def _all_gather_block_weight_memory_pool(self, block_index: int):
 | 
			
		||||
        fstp_modules = self.index_to_fstp_modules[block_index]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -164,9 +164,7 @@ def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup,
 | 
			
		|||
    world_size = torch.distributed.get_world_size(process_group)
 | 
			
		||||
    assert input_.shape[0] % world_size == 0
 | 
			
		||||
    size = (input_.shape[0] // world_size, *input_.shape[1:])
 | 
			
		||||
    index = gpc.fstp_handler.get_reduce_scatter_memory(size)
 | 
			
		||||
    output = gpc.fstp_handler.reduce_scatter_memory_pool[size]["data"][index]
 | 
			
		||||
    setattr(output, "index", index)
 | 
			
		||||
    output = gpc.fstp_handler.get_reduce_scatter_memory(size)
 | 
			
		||||
    handle = torch.distributed.reduce_scatter_tensor(
 | 
			
		||||
        output, input_.contiguous(), group=process_group, async_op=async_op
 | 
			
		||||
    )
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -350,7 +350,7 @@ class HybridZeroOptimizer(BaseOptimizer):
 | 
			
		|||
            _param.grad.add_(_grad)
 | 
			
		||||
 | 
			
		||||
            # release cuda memory.
 | 
			
		||||
            gpc.fstp_handler.release_reduce_scatter_memory(size=tuple(_grad.size()), index=_grad.index)
 | 
			
		||||
            gpc.fstp_handler.release_reduce_scatter_memory(key=tuple(_grad.size()), index=_grad.index)
 | 
			
		||||
            self._fstp_handler.reduce_scatter_handlers[_key] = None
 | 
			
		||||
 | 
			
		||||
        bucket.reset_by_rank(reduce_rank)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue