[zero] fixed api consistency (#1098)

pull/1096/head
Frank Lee 2022-06-10 16:59:59 +08:00 committed by GitHub
parent cb18922c47
commit 14e5b11d7f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 3 deletions

View File

@ -73,11 +73,11 @@ class GeminiManager:
def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...], group_name: str):
cuda_demand = 0
for chunk in chunks:
if chunk.device_type == 'cpu' or chunk.is_free:
if chunk.device_type == 'cpu' or chunk.is_empty:
cuda_demand += chunk.mem
can_evict_chunks = []
for chunk in self._chunk_manager.chunk_groups[group_name]:
if not chunk.is_free and chunk.device_type == 'cuda' and chunk.can_move_device:
if not chunk.is_empty and chunk.device_type == 'cuda' and chunk.can_move_device:
can_evict_chunks.append(chunk)
return cuda_demand, can_evict_chunks

View File

@ -136,7 +136,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
fp32_params_used_cuda_margin_mem = 0
for fp16_param_chunk, fp32_param_chunk in zip(self.chunk_manager.chunk_groups['fp16_param'],
self.chunk_manager.chunk_groups['fp32_param']):
if fp32_param_chunk.is_free:
if fp32_param_chunk.is_empty:
continue
if fp32_params_used_cuda_margin_mem + fp32_param_chunk.mem < fp32_params_available_cuda_margin_mem:
self.chunk_manager.move_chunk(fp32_param_chunk, get_current_device())