mirror of https://github.com/hpcaitech/ColossalAI
[zero] add has_inf_or_nan in AgChunk; enhance the unit test of AgChunk (#1426)
parent 33f0744d51
commit 0d212183c4

The commit touches two files: the AgChunk class and its unit test.
(first file: the AgChunk class)

@@ -1,6 +1,6 @@
 import torch
 import torch.distributed as dist
-from typing import Optional, Dict
+from typing import Optional, Dict, List

 from colossalai.utils import get_current_device
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
@@ -45,10 +45,11 @@ class AgChunk:
         self.shard_size = chunk_size // self.pg_size
         self.shard_begin = self.shard_size * self.pg_rank
         self.shard_end = self.shard_begin + self.shard_size
+        self.valid_end = self.shard_size

         self.dtype = dtype
         device = init_device or get_current_device()
-        self.chunk_temp = torch.empty(chunk_size, dtype=dtype, device=device)
+        self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device)    # keep all zero
         self.chunk_total = None    # we force chunk_total located in CUDA
         self.cuda_shard = None    # using two attributes for the better interpretation
         self.cpu_shard = None
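Note on the bookkeeping above: one flat chunk is split evenly across the process group, and valid_end starts out as the full shard length. A minimal standalone sketch of that arithmetic, with made-up sizes (not ColossalAI code):

chunk_size, pg_size = 1024, 4    # hypothetical chunk of 1024 elements on 4 ranks

for pg_rank in range(pg_size):
    shard_size = chunk_size // pg_size      # 256 elements owned by each rank
    shard_begin = shard_size * pg_rank      # start of this rank's slice
    shard_end = shard_begin + shard_size    # end of this rank's slice (exclusive)
    valid_end = shard_size                  # until the chunk is closed, the whole shard counts as valid
    print(pg_rank, shard_begin, shard_end, valid_end)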
@@ -114,7 +115,7 @@ class AgChunk:
         if self.chunk_temp is not None:
             return self.chunk_temp.device.type
         else:
-            if self.chunk_total is not None:
+            if self.is_gathered:
                 return 'cuda'
             elif self.cuda_shard is not None:
                 return 'cuda'
@@ -153,6 +154,12 @@ class AgChunk:
         # sanity check
         assert self.chunk_temp is not None

+        # calculate the valid end for each shard
+        if self.utilized_size <= self.shard_begin:
+            self.valid_end = 0
+        elif self.utilized_size < self.shard_end:
+            self.valid_end = self.utilized_size - self.shard_begin
+
         if self.chunk_temp.device.type == 'cpu':
             self.chunk_total = self.chunk_temp.to(get_current_device())
         else:
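The added branch clips valid_end to the portion of each shard that actually holds tensor data once only utilized_size elements of the chunk are used; the rest is padding. The same arithmetic as a standalone sketch with assumed sizes:

# Hypothetical numbers: a chunk of 1024 split over 4 ranks, only 900 elements used.
chunk_size, pg_size, utilized_size = 1024, 4, 900
shard_size = chunk_size // pg_size

for pg_rank in range(pg_size):
    shard_begin = shard_size * pg_rank
    shard_end = shard_begin + shard_size
    valid_end = shard_size                       # default: the whole shard is valid
    if utilized_size <= shard_begin:
        valid_end = 0                            # shard lies entirely in the padding
    elif utilized_size < shard_end:
        valid_end = utilized_size - shard_begin  # shard is only partially used
    print(pg_rank, valid_end)                    # 256, 256, 256, 132

# Summing valid_end over all ranks recovers utilized_size (900), which is
# exactly what the dist_sum check in the enhanced unit test verifies.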
@@ -257,7 +264,7 @@ class AgChunk:
                 self.shard_size, dtype=self.dtype, device=get_current_device())

             input_list = list(torch.chunk(self.chunk_total, chunks=self.pg_size, dim=0))
-            dist.reduce_scatter(self.cuda_shard, input_list, self.torch_pg)
+            dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)

             free_storage(self.chunk_total)
             self.is_gathered = False
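The small-looking change above is a real fix: torch.distributed.reduce_scatter is declared as reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=False), so a positionally passed process group would be bound to op. A hedged sketch of the corrected call pattern; reduce_scatter_shard is a made-up helper and the script assumes it is launched with torchrun on a single node with CUDA devices:

import torch
import torch.distributed as dist


def reduce_scatter_shard(chunk_total: torch.Tensor, pg_size: int, group=None) -> torch.Tensor:
    # Split the flat chunk into pg_size equal pieces and reduce-scatter them,
    # mirroring what the AgChunk reduction above does.
    shard = torch.empty(chunk_total.numel() // pg_size,
                        dtype=chunk_total.dtype, device=chunk_total.device)
    input_list = list(torch.chunk(chunk_total, chunks=pg_size, dim=0))
    # The process group must be passed as a keyword: the third positional
    # argument of dist.reduce_scatter is `op`, not `group`.
    dist.reduce_scatter(shard, input_list, group=group)
    return shard


if __name__ == '__main__':
    # e.g. `torchrun --nproc_per_node=2 reduce_scatter_sketch.py`
    dist.init_process_group(backend='nccl')
    torch.cuda.set_device(dist.get_rank())    # single-node assumption
    world_size = dist.get_world_size()
    total = torch.ones(1024, device='cuda')
    shard = reduce_scatter_shard(total, world_size)
    assert torch.all(shard == world_size)     # every element was summed over all ranks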
@@ -298,17 +305,38 @@ class AgChunk:
         assert self.is_gathered

         tensor_info = self.tensors_info[tensor]
-        self.chunk_total[tensor_info.offset:tensor_info.end].copy_(data_slice.flatten())
+        self.chunk_total[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten())
         tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape)

+    @property
+    def can_move(self) -> bool:
+        return not self.is_gathered
+
     @property
     def can_release(self) -> bool:
-        return self.tensors_state_monitor[TensorState.HOLD] == self.num_tensors
+        if self.keep_gathered:
+            return False
+        else:
+            return self.tensors_state_monitor[TensorState.HOLD] + \
+                self.tensors_state_monitor[TensorState.HOLD_AFTER_BWD] == self.num_tensors

     @property
     def can_reduce(self):
         return self.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == self.num_tensors

+    @property
+    def has_inf_or_nan(self) -> bool:
+        """
+        Check if the chunk has inf or nan values in CUDA.
+        """
+        if self.is_gathered:
+            valid_tensor = self.chunk_total[:self.utilized_size]
+        else:
+            assert self.cuda_shard is not None    # only check in CUDA
+            valid_tensor = self.cuda_shard[:self.valid_end]
+
+        return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
+
     def __gather(self):
         if not self.is_gathered:
             # sanity check
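has_inf_or_nan only scans the valid region: the used prefix of the gathered chunk, or the local CUDA shard up to valid_end. A CPU-only illustration of the expression it relies on (in the chunk itself the buffers live on CUDA):

import torch

# Only the valid prefix of the shard is inspected; padding past valid_end is ignored.
shard = torch.tensor([1.0, 2.0, 3.0, float('nan')])    # the NaN sits in the padding region
valid_end = 3

valid_tensor = shard[:valid_end]
flag = torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
assert not flag             # padding is not scanned, so no overflow is reported

valid_tensor = shard[:4]    # if the NaN were inside the valid region...
flag = torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
assert flag                 # ...the property would return True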
@@ -375,6 +403,12 @@ class AgChunk:
         if prev_state is None or tensor_info.state == prev_state:
             self.__update_one_tensor_info(tensor_info, next_state)

+    def __hash__(self) -> int:
+        return hash(id(self))
+
+    def __eq__(self, __o: object) -> bool:
+        return self is __o
+
     def __repr__(self, detailed: bool = False):
         output = [
             "AgChunk Information:\n",
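Overriding __eq__ alone would set __hash__ to None and make chunks unhashable, so the identity-based pair added above keeps AgChunk usable as a dict key or set member, which a chunk manager typically needs. A small sketch with a stand-in class:

class IdentityHashed:
    """Stand-in for AgChunk's identity-based __hash__/__eq__ pair."""

    def __hash__(self) -> int:
        return hash(id(self))

    def __eq__(self, other: object) -> bool:
        return self is other


a, b = IdentityHashed(), IdentityHashed()
accessed = {a}                                  # usable as a set member / dict key
assert a in accessed and b not in accessed      # equality is object identity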
@@ -413,3 +447,6 @@ class AgChunk:
                 output.append("\t\t# of {}: {}\n".format(st, self.tensors_state_monitor[st]))

         return ''.join(output)
+
+    def get_tensors(self) -> List[torch.Tensor]:
+        return list(self.tensors_info.keys())
(second file: the unit test of AgChunk)

@@ -2,16 +2,24 @@ import torch
 import colossalai
 import pytest
 import torch.multiprocessing as mp
+import torch.distributed as dist
 from functools import partial
 from colossalai.testing import rerun_if_address_is_in_use, parameterize
 from colossalai.utils import free_port, get_current_device
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
 from colossalai.tensor import ColoParameter
+from colossalai.gemini import TensorState
 from colossalai.gemini.ag_chunk import AgChunk


+def dist_sum(x):
+    temp = torch.tensor([x], device=get_current_device())
+    dist.all_reduce(temp)
+    return temp.item()
+
+
 def add_param(param_list, param_cp_list, *args, **kwargs):
-    param = ColoParameter(torch.empty(*args, **kwargs))
+    param = ColoParameter(torch.randn(*args, **kwargs))
     param_list.append(param)
     param_cp_list.append(param.clone())

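Two details in this hunk are easy to miss: dist_sum all-reduces a scalar so every rank sees the sum of the per-rank valid_end values, and torch.empty is replaced by torch.randn because empty returns uninitialized memory that may already contain inf/NaN and would make the new has_inf_or_nan assertion flaky. A quick standalone check of the second point (no process group needed):

import torch

# torch.randn always yields finite values, so the parameters copied into the
# chunk cannot trip the has_inf_or_nan check by accident.
p = torch.randn(4, 8)
assert torch.isfinite(p).all()
# torch.empty(4, 8) gives whatever bytes happen to be in memory and offers
# no such guarantee, which is why the test switches to randn.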
@@ -27,7 +35,7 @@ def check_euqal(param, param_cp):
 @parameterize('init_device', [None, torch.device('cpu')])
 @parameterize('keep_gathered', [True, False])
 @parameterize('pin_memory', [True, False])
-def exam_chunk_init(init_device, keep_gathered, pin_memory):
+def exam_chunk_basic(init_device, keep_gathered, pin_memory):
     world_size = torch.distributed.get_world_size()
     pg = ColoProcessGroup()
     my_chunk = AgChunk(
@@ -56,17 +64,51 @@ def exam_chunk_init(init_device, keep_gathered, pin_memory):

     if keep_gathered is False:
         assert my_chunk.cpu_shard.size(0) == 1024 // world_size
+        assert my_chunk.device_type == 'cpu'
+        assert my_chunk.can_move
         my_chunk.shard_move(get_current_device())
+    else:
+        assert my_chunk.chunk_total.size(0) == 1024
+        assert my_chunk.device_type == 'cuda'
+        assert not my_chunk.can_move
+
+    assert dist_sum(my_chunk.valid_end) == my_chunk.utilized_size
+    flag = my_chunk.has_inf_or_nan
+    assert not flag, "has_inf_or_nan is {}".format(flag)
+
     my_chunk.access_chunk()
+    assert my_chunk.device_type == 'cuda'
     for param, param_cp in zip(param_list, param_cp_list):
         check_euqal(param, param_cp)
+
+    assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 4
+    my_chunk.tensor_trans_state(param_list[0], TensorState.COMPUTE)
+    assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 3
+    assert my_chunk.tensors_state_monitor[TensorState.COMPUTE] == 1
+    assert not my_chunk.can_release
+
+    for param in param_list:
+        my_chunk.tensor_trans_state(param, TensorState.COMPUTE)
+        my_chunk.tensor_trans_state(param, TensorState.READY_FOR_REDUCE)
+
+    assert my_chunk.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == 4
+    assert my_chunk.can_reduce
+    my_chunk.reduce()
+    assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 4
+
+    if keep_gathered is False:
+        assert my_chunk.cuda_shard.size(0) == 1024 // world_size
+        assert my_chunk.device_type == 'cuda'
+        assert my_chunk.can_move
+    else:
+        assert my_chunk.chunk_total.size(0) == 1024
+        assert my_chunk.device_type == 'cuda'
+        assert not my_chunk.can_move


 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    exam_chunk_init()
+    exam_chunk_basic()


 @pytest.mark.dist
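For context, ColossalAI dist tests with this import set are usually driven by mp.spawn from a pytest entry point. The sketch below shows that conventional shape only as an assumption; the actual body of test_chunk_function is outside this diff, and the parametrized world sizes are guesses.

# Hedged sketch of the conventional launcher; run_dist is a placeholder here,
# standing in for the function defined in the test file above.
from functools import partial

import pytest
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port


def run_dist(rank, world_size, port):
    # Placeholder: the real run_dist launches colossalai and calls exam_chunk_basic().
    pass


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2, 4])    # assumed values
@rerun_if_address_is_in_use()
def test_chunk_function(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)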
@@ -78,4 +120,4 @@ def test_chunk_function(world_size):


 if __name__ == '__main__':
-    test_chunk_function(2)
+    test_chunk_function(4)