[zero] add has_inf_or_nan in AgChunk; enhance the unit test of AgChunk (#1426)

pull/1424/head^2
HELSON 2022-08-10 11:37:28 +08:00 committed by GitHub
parent 33f0744d51
commit 0d212183c4
2 changed files with 90 additions and 11 deletions
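Note: the core of the new has_inf_or_nan property is a plain isinf/isnan scan over the chunk's valid region (the gathered chunk_total, or the CUDA shard up to valid_end). A minimal CPU-only sketch of that predicate, separate from the chunk API and shown only for illustration:

import torch

def tensor_has_inf_or_nan(t: torch.Tensor) -> bool:
    # same check the new property applies to the chunk's valid region
    return bool(torch.isinf(t).any().item() | torch.isnan(t).any().item())

clean = torch.randn(8)
dirty = clean.clone()
dirty[3] = float('nan')
assert not tensor_has_inf_or_nan(clean)
assert tensor_has_inf_or_nan(dirty)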

View File

@@ -1,6 +1,6 @@
 import torch
 import torch.distributed as dist
-from typing import Optional, Dict
+from typing import Optional, Dict, List
 from colossalai.utils import get_current_device
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
@@ -45,10 +45,11 @@ class AgChunk:
         self.shard_size = chunk_size // self.pg_size
         self.shard_begin = self.shard_size * self.pg_rank
         self.shard_end = self.shard_begin + self.shard_size
+        self.valid_end = self.shard_size

         self.dtype = dtype
         device = init_device or get_current_device()
-        self.chunk_temp = torch.empty(chunk_size, dtype=dtype, device=device)
+        self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device)    # keep all zero
         self.chunk_total = None    # we force chunk_total located in CUDA
         self.cuda_shard = None    # using two attributes for the better interpretation
         self.cpu_shard = None
@@ -114,7 +115,7 @@ class AgChunk:
         if self.chunk_temp is not None:
             return self.chunk_temp.device.type
         else:
-            if self.chunk_total is not None:
+            if self.is_gathered:
                 return 'cuda'
             elif self.cuda_shard is not None:
                 return 'cuda'
@@ -153,6 +154,12 @@ class AgChunk:
         # sanity check
         assert self.chunk_temp is not None

+        # calculate the valid end for each shard
+        if self.utilized_size <= self.shard_begin:
+            self.valid_end = 0
+        elif self.utilized_size < self.shard_end:
+            self.valid_end = self.utilized_size - self.shard_begin
+
         if self.chunk_temp.device.type == 'cpu':
             self.chunk_total = self.chunk_temp.to(get_current_device())
         else:
@@ -257,7 +264,7 @@ class AgChunk:
                 self.shard_size, dtype=self.dtype, device=get_current_device())

             input_list = list(torch.chunk(self.chunk_total, chunks=self.pg_size, dim=0))
-            dist.reduce_scatter(self.cuda_shard, input_list, self.torch_pg)
+            dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)

             free_storage(self.chunk_total)
             self.is_gathered = False
@@ -298,17 +305,38 @@ class AgChunk:
         assert self.is_gathered

         tensor_info = self.tensors_info[tensor]
-        self.chunk_total[tensor_info.offset:tensor_info.end].copy_(data_slice.flatten())
+        self.chunk_total[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten())
         tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape)

+    @property
+    def can_move(self) -> bool:
+        return not self.is_gathered
+
     @property
     def can_release(self) -> bool:
-        return self.tensors_state_monitor[TensorState.HOLD] == self.num_tensors
+        if self.keep_gathered:
+            return False
+        else:
+            return self.tensors_state_monitor[TensorState.HOLD] + \
+                self.tensors_state_monitor[TensorState.HOLD_AFTER_BWD] == self.num_tensors

     @property
     def can_reduce(self):
         return self.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == self.num_tensors

+    @property
+    def has_inf_or_nan(self) -> bool:
+        """
+        Check if the chunk has inf or nan values in CUDA.
+        """
+        if self.is_gathered:
+            valid_tensor = self.chunk_total[:self.utilized_size]
+        else:
+            assert self.cuda_shard is not None    # only check in CUDA
+            valid_tensor = self.cuda_shard[:self.valid_end]
+
+        return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
+
     def __gather(self):
         if not self.is_gathered:
             # sanity check
@@ -375,6 +403,12 @@ class AgChunk:
         if prev_state is None or tensor_info.state == prev_state:
             self.__update_one_tensor_info(tensor_info, next_state)

+    def __hash__(self) -> int:
+        return hash(id(self))
+
+    def __eq__(self, __o: object) -> bool:
+        return self is __o
+
     def __repr__(self, detailed: bool = False):
         output = [
             "AgChunk Information:\n",
@@ -413,3 +447,6 @@ class AgChunk:
             output.append("\t\t# of {}: {}\n".format(st, self.tensors_state_monitor[st]))
         return ''.join(output)
+
+    def get_tensors(self) -> List[torch.Tensor]:
+        return list(self.tensors_info.keys())
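To make the new valid_end bookkeeping concrete: with chunk_size=1024 split across 4 ranks (shard_size=256) and utilized_size=600, the per-rank valid ends are 256, 256, 88 and 0, and their sum equals utilized_size. The sketch below restates the rule from close_chunk() as a standalone function; the concrete sizes are illustrative, not taken from the test.

def shard_valid_end(rank: int, chunk_size: int, pg_size: int, utilized_size: int) -> int:
    # mirrors the valid_end rule added in close_chunk()
    shard_size = chunk_size // pg_size
    shard_begin = shard_size * rank
    shard_end = shard_begin + shard_size
    if utilized_size <= shard_begin:
        return 0
    elif utilized_size < shard_end:
        return utilized_size - shard_begin
    return shard_size    # shard lies fully inside the utilized region

ends = [shard_valid_end(r, chunk_size=1024, pg_size=4, utilized_size=600) for r in range(4)]
assert ends == [256, 256, 88, 0]
assert sum(ends) == 600    # what the test checks via dist_sum(my_chunk.valid_end)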

View File

@@ -2,16 +2,24 @@ import torch
 import colossalai
 import pytest
 import torch.multiprocessing as mp
+import torch.distributed as dist
 from functools import partial
 from colossalai.testing import rerun_if_address_is_in_use, parameterize
 from colossalai.utils import free_port, get_current_device
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
 from colossalai.tensor import ColoParameter
+from colossalai.gemini import TensorState
 from colossalai.gemini.ag_chunk import AgChunk


+def dist_sum(x):
+    temp = torch.tensor([x], device=get_current_device())
+    dist.all_reduce(temp)
+    return temp.item()
+
+
 def add_param(param_list, param_cp_list, *args, **kwargs):
-    param = ColoParameter(torch.empty(*args, **kwargs))
+    param = ColoParameter(torch.randn(*args, **kwargs))
     param_list.append(param)
     param_cp_list.append(param.clone())
@@ -27,7 +35,7 @@ def check_euqal(param, param_cp):
 @parameterize('init_device', [None, torch.device('cpu')])
 @parameterize('keep_gathered', [True, False])
 @parameterize('pin_memory', [True, False])
-def exam_chunk_init(init_device, keep_gathered, pin_memory):
+def exam_chunk_basic(init_device, keep_gathered, pin_memory):
     world_size = torch.distributed.get_world_size()
     pg = ColoProcessGroup()
     my_chunk = AgChunk(
@@ -56,17 +64,51 @@ def exam_chunk_init(init_device, keep_gathered, pin_memory):
     if keep_gathered is False:
         assert my_chunk.cpu_shard.size(0) == 1024 // world_size
+        assert my_chunk.device_type == 'cpu'
+        assert my_chunk.can_move
         my_chunk.shard_move(get_current_device())
+    else:
+        assert my_chunk.chunk_total.size(0) == 1024
+        assert my_chunk.device_type == 'cuda'
+        assert not my_chunk.can_move
+
+    assert dist_sum(my_chunk.valid_end) == my_chunk.utilized_size
+    flag = my_chunk.has_inf_or_nan
+    assert not flag, "has_inf_or_nan is {}".format(flag)

     my_chunk.access_chunk()
+    assert my_chunk.device_type == 'cuda'

     for param, param_cp in zip(param_list, param_cp_list):
         check_euqal(param, param_cp)
+    assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 4
+
+    my_chunk.tensor_trans_state(param_list[0], TensorState.COMPUTE)
+    assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 3
+    assert my_chunk.tensors_state_monitor[TensorState.COMPUTE] == 1
+    assert not my_chunk.can_release
+
+    for param in param_list:
+        my_chunk.tensor_trans_state(param, TensorState.COMPUTE)
+        my_chunk.tensor_trans_state(param, TensorState.READY_FOR_REDUCE)
+
+    assert my_chunk.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == 4
+    assert my_chunk.can_reduce
+    my_chunk.reduce()
+    assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 4
+
+    if keep_gathered is False:
+        assert my_chunk.cuda_shard.size(0) == 1024 // world_size
+        assert my_chunk.device_type == 'cuda'
+        assert my_chunk.can_move
+    else:
+        assert my_chunk.chunk_total.size(0) == 1024
+        assert my_chunk.device_type == 'cuda'
+        assert not my_chunk.can_move


 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    exam_chunk_init()
+    exam_chunk_basic()
@@ -78,4 +120,4 @@ def test_chunk_function(world_size):

 if __name__ == '__main__':
-    test_chunk_function(2)
+    test_chunk_function(4)
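The new assertions in exam_chunk_basic walk the four parameters through the chunk's tensor-state machine (HOLD -> COMPUTE -> READY_FOR_REDUCE -> reduce() -> HOLD). Below is a toy, single-process mirror of the tensors_state_monitor bookkeeping those assertions rely on; the Counter and the reduced TensorState enum are illustrative stand-ins, not the real colossalai.gemini implementation.

from collections import Counter
from enum import Enum, auto

class TensorState(Enum):    # reduced stand-in; member names follow the diff
    HOLD = auto()
    COMPUTE = auto()
    READY_FOR_REDUCE = auto()

monitor = Counter({TensorState.HOLD: 4})    # four params start in HOLD

def trans_state(old: TensorState, new: TensorState) -> None:
    monitor[old] -= 1
    monitor[new] += 1

trans_state(TensorState.HOLD, TensorState.COMPUTE)
assert monitor[TensorState.HOLD] == 3 and monitor[TensorState.COMPUTE] == 1

# move everything to READY_FOR_REDUCE, as the test does before calling reduce()
trans_state(TensorState.COMPUTE, TensorState.READY_FOR_REDUCE)
for _ in range(3):
    trans_state(TensorState.HOLD, TensorState.COMPUTE)
    trans_state(TensorState.COMPUTE, TensorState.READY_FOR_REDUCE)
assert monitor[TensorState.READY_FOR_REDUCE] == 4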