From 0d212183c47cbc09f9d7cb32006057bc3694e370 Mon Sep 17 00:00:00 2001
From: HELSON
Date: Wed, 10 Aug 2022 11:37:28 +0800
Subject: [PATCH] [zero] add has_inf_or_nan in AgChunk; enhance the unit test of AgChunk (#1426)

---
 colossalai/gemini/ag_chunk.py           | 49 ++++++++++++++++++++---
 tests/test_gemini/chunk/test_agchunk.py | 52 ++++++++++++++++++++++---
 2 files changed, 90 insertions(+), 11 deletions(-)

diff --git a/colossalai/gemini/ag_chunk.py b/colossalai/gemini/ag_chunk.py
index 0b8a47667..cdeb78222 100644
--- a/colossalai/gemini/ag_chunk.py
+++ b/colossalai/gemini/ag_chunk.py
@@ -1,6 +1,6 @@
 import torch
 import torch.distributed as dist
-from typing import Optional, Dict
+from typing import Optional, Dict, List
 
 from colossalai.utils import get_current_device
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
@@ -45,10 +45,11 @@
         self.shard_size = chunk_size // self.pg_size
         self.shard_begin = self.shard_size * self.pg_rank
         self.shard_end = self.shard_begin + self.shard_size
+        self.valid_end = self.shard_size
 
         self.dtype = dtype
         device = init_device or get_current_device()
-        self.chunk_temp = torch.empty(chunk_size, dtype=dtype, device=device)
+        self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device)    # keep all zero
         self.chunk_total = None    # we force chunk_total located in CUDA
         self.cuda_shard = None    # using two attributes for the better interpretation
         self.cpu_shard = None
@@ -114,7 +115,7 @@ class AgChunk:
         if self.chunk_temp is not None:
             return self.chunk_temp.device.type
         else:
-            if self.chunk_total is not None:
+            if self.is_gathered:
                 return 'cuda'
             elif self.cuda_shard is not None:
                 return 'cuda'
@@ -153,6 +154,12 @@
         # sanity check
         assert self.chunk_temp is not None
 
+        # calculate the valid end for each shard
+        if self.utilized_size <= self.shard_begin:
+            self.valid_end = 0
+        elif self.utilized_size < self.shard_end:
+            self.valid_end = self.utilized_size - self.shard_begin
+
         if self.chunk_temp.device.type == 'cpu':
             self.chunk_total = self.chunk_temp.to(get_current_device())
         else:
@@ -257,7 +264,7 @@
                                           self.shard_size, dtype=self.dtype, device=get_current_device())
 
             input_list = list(torch.chunk(self.chunk_total, chunks=self.pg_size, dim=0))
-            dist.reduce_scatter(self.cuda_shard, input_list, self.torch_pg)
+            dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
 
             free_storage(self.chunk_total)
             self.is_gathered = False
@@ -298,17 +305,38 @@
         assert self.is_gathered
 
         tensor_info = self.tensors_info[tensor]
-        self.chunk_total[tensor_info.offset:tensor_info.end].copy_(data_slice.flatten())
+        self.chunk_total[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten())
         tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape)
 
+    @property
+    def can_move(self) -> bool:
+        return not self.is_gathered
+
     @property
     def can_release(self) -> bool:
-        return self.tensors_state_monitor[TensorState.HOLD] == self.num_tensors
+        if self.keep_gathered:
+            return False
+        else:
+            return self.tensors_state_monitor[TensorState.HOLD] + \
+                self.tensors_state_monitor[TensorState.HOLD_AFTER_BWD] == self.num_tensors
 
     @property
     def can_reduce(self):
         return self.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == self.num_tensors
 
+    @property
+    def has_inf_or_nan(self) -> bool:
+        """
+        Check if the chunk has inf or nan values in CUDA.
+        """
+        if self.is_gathered:
+            valid_tensor = self.chunk_total[:self.utilized_size]
+        else:
+            assert self.cuda_shard is not None    # only check in CUDA
+            valid_tensor = self.cuda_shard[:self.valid_end]
+
+        return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
+
     def __gather(self):
         if not self.is_gathered:
             # sanity check
@@ -375,6 +403,12 @@
             if prev_state is None or tensor_info.state == prev_state:
                 self.__update_one_tensor_info(tensor_info, next_state)
 
+    def __hash__(self) -> int:
+        return hash(id(self))
+
+    def __eq__(self, __o: object) -> bool:
+        return self is __o
+
     def __repr__(self, detailed: bool = False):
         output = [
             "AgChunk Information:\n",
@@ -413,3 +447,6 @@
             output.append("\t\t# of {}: {}\n".format(st, self.tensors_state_monitor[st]))
 
         return ''.join(output)
+
+    def get_tensors(self) -> List[torch.Tensor]:
+        return list(self.tensors_info.keys())
diff --git a/tests/test_gemini/chunk/test_agchunk.py b/tests/test_gemini/chunk/test_agchunk.py
index d97a0200e..005c6503b 100644
--- a/tests/test_gemini/chunk/test_agchunk.py
+++ b/tests/test_gemini/chunk/test_agchunk.py
@@ -2,16 +2,24 @@ import torch
 import colossalai
 import pytest
 import torch.multiprocessing as mp
+import torch.distributed as dist
 
 from functools import partial
 from colossalai.testing import rerun_if_address_is_in_use, parameterize
 from colossalai.utils import free_port, get_current_device
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
 from colossalai.tensor import ColoParameter
+from colossalai.gemini import TensorState
 from colossalai.gemini.ag_chunk import AgChunk
 
 
+def dist_sum(x):
+    temp = torch.tensor([x], device=get_current_device())
+    dist.all_reduce(temp)
+    return temp.item()
+
+
 def add_param(param_list, param_cp_list, *args, **kwargs):
-    param = ColoParameter(torch.empty(*args, **kwargs))
+    param = ColoParameter(torch.randn(*args, **kwargs))
     param_list.append(param)
     param_cp_list.append(param.clone())
@@ -27,7 +35,7 @@ def check_euqal(param, param_cp):
 @parameterize('init_device', [None, torch.device('cpu')])
 @parameterize('keep_gathered', [True, False])
 @parameterize('pin_memory', [True, False])
-def exam_chunk_init(init_device, keep_gathered, pin_memory):
+def exam_chunk_basic(init_device, keep_gathered, pin_memory):
     world_size = torch.distributed.get_world_size()
     pg = ColoProcessGroup()
     my_chunk = AgChunk(
@@ -56,17 +64,51 @@
 
     if keep_gathered is False:
         assert my_chunk.cpu_shard.size(0) == 1024 // world_size
+        assert my_chunk.device_type == 'cpu'
+        assert my_chunk.can_move
         my_chunk.shard_move(get_current_device())
+    else:
+        assert my_chunk.chunk_total.size(0) == 1024
+        assert my_chunk.device_type == 'cuda'
+        assert not my_chunk.can_move
+
+    assert dist_sum(my_chunk.valid_end) == my_chunk.utilized_size
+    flag = my_chunk.has_inf_or_nan
+    assert not flag, "has_inf_or_nan is {}".format(flag)
 
     my_chunk.access_chunk()
-
+    assert my_chunk.device_type == 'cuda'
     for param, param_cp in zip(param_list, param_cp_list):
         check_euqal(param, param_cp)
+
+    assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 4
+    my_chunk.tensor_trans_state(param_list[0], TensorState.COMPUTE)
+    assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 3
+    assert my_chunk.tensors_state_monitor[TensorState.COMPUTE] == 1
+    assert not my_chunk.can_release
+
+    for param in param_list:
+        my_chunk.tensor_trans_state(param, TensorState.COMPUTE)
+        my_chunk.tensor_trans_state(param, TensorState.READY_FOR_REDUCE)
+
+    assert my_chunk.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == 4
+    assert my_chunk.can_reduce
+    my_chunk.reduce()
+    assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 4
+
+    if keep_gathered is False:
+        assert my_chunk.cuda_shard.size(0) == 1024 // world_size
+        assert my_chunk.device_type == 'cuda'
+        assert my_chunk.can_move
+    else:
+        assert my_chunk.chunk_total.size(0) == 1024
+        assert my_chunk.device_type == 'cuda'
+        assert not my_chunk.can_move
 
 
 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    exam_chunk_init()
+    exam_chunk_basic()
 
 
 @pytest.mark.dist
@@ -78,4 +120,4 @@ def test_chunk_function(world_size):
 
 
 if __name__ == '__main__':
-    test_chunk_function(2)
+    test_chunk_function(4)
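Usage note (not part of the patch): the sketch below shows one way the new has_inf_or_nan property could feed a per-rank overflow check, assuming the AgChunk API added above. The helper name found_overflow, its pg argument, and the caller-maintained chunk list are illustrative assumptions rather than ColossalAI API; has_inf_or_nan asserts that a non-gathered chunk still holds its CUDA shard, so the sketch assumes no shard has been offloaded to CPU.

from typing import Iterable

import torch
import torch.distributed as dist

from colossalai.utils import get_current_device
from colossalai.gemini.ag_chunk import AgChunk


def found_overflow(chunks: Iterable[AgChunk], pg=None) -> bool:
    # Hypothetical helper: each rank only inspects the payload it currently
    # holds (chunk_total[:utilized_size] when gathered, cuda_shard[:valid_end]
    # otherwise), so the local flags are summed across ranks before deciding
    # whether to skip the optimizer step.
    local_flag = any(chunk.has_inf_or_nan for chunk in chunks)
    flag = torch.tensor([int(local_flag)], device=get_current_device())
    dist.all_reduce(flag, group=pg)
    return flag.item() > 0

Because valid_end counts only the elements a rank actually uses, every utilized element is checked exactly once and padding is ignored, which mirrors the dist_sum(my_chunk.valid_end) == my_chunk.utilized_size assertion in the enhanced unit test.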