mirror of https://github.com/hpcaitech/ColossalAI
[zero] fixed api consistency (#1098)
parent
cb18922c47
commit
14e5b11d7f
|
@ -73,11 +73,11 @@ class GeminiManager:
|
||||||
def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...], group_name: str):
|
def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...], group_name: str):
|
||||||
cuda_demand = 0
|
cuda_demand = 0
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
if chunk.device_type == 'cpu' or chunk.is_free:
|
if chunk.device_type == 'cpu' or chunk.is_empty:
|
||||||
cuda_demand += chunk.mem
|
cuda_demand += chunk.mem
|
||||||
can_evict_chunks = []
|
can_evict_chunks = []
|
||||||
for chunk in self._chunk_manager.chunk_groups[group_name]:
|
for chunk in self._chunk_manager.chunk_groups[group_name]:
|
||||||
if not chunk.is_free and chunk.device_type == 'cuda' and chunk.can_move_device:
|
if not chunk.is_empty and chunk.device_type == 'cuda' and chunk.can_move_device:
|
||||||
can_evict_chunks.append(chunk)
|
can_evict_chunks.append(chunk)
|
||||||
return cuda_demand, can_evict_chunks
|
return cuda_demand, can_evict_chunks
|
||||||
|
|
||||||
|
|
|
@ -136,7 +136,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
||||||
fp32_params_used_cuda_margin_mem = 0
|
fp32_params_used_cuda_margin_mem = 0
|
||||||
for fp16_param_chunk, fp32_param_chunk in zip(self.chunk_manager.chunk_groups['fp16_param'],
|
for fp16_param_chunk, fp32_param_chunk in zip(self.chunk_manager.chunk_groups['fp16_param'],
|
||||||
self.chunk_manager.chunk_groups['fp32_param']):
|
self.chunk_manager.chunk_groups['fp32_param']):
|
||||||
if fp32_param_chunk.is_free:
|
if fp32_param_chunk.is_empty:
|
||||||
continue
|
continue
|
||||||
if fp32_params_used_cuda_margin_mem + fp32_param_chunk.mem < fp32_params_available_cuda_margin_mem:
|
if fp32_params_used_cuda_margin_mem + fp32_param_chunk.mem < fp32_params_available_cuda_margin_mem:
|
||||||
self.chunk_manager.move_chunk(fp32_param_chunk, get_current_device())
|
self.chunk_manager.move_chunk(fp32_param_chunk, get_current_device())
|
||||||
|
|
Loading…
Reference in New Issue