mirror of https://github.com/hpcaitech/ColossalAI
[zero] add unit test for AgChunk's append, close, access (#1423)
parent
c577ed016e
commit
4fb3c52cf0
|
@ -36,7 +36,7 @@ class AgChunk:
|
||||||
self.utilized_size = 0
|
self.utilized_size = 0
|
||||||
# Here, we use torch process group,
|
# Here, we use torch process group,
|
||||||
# since ColoProcessGroup might get deprecated soon
|
# since ColoProcessGroup might get deprecated soon
|
||||||
self.torch_pg = process_group.dp_process_group
|
self.torch_pg = process_group.dp_process_group()
|
||||||
self.pg_size = dist.get_world_size(self.torch_pg)
|
self.pg_size = dist.get_world_size(self.torch_pg)
|
||||||
self.pg_rank = dist.get_rank(self.torch_pg)
|
self.pg_rank = dist.get_rank(self.torch_pg)
|
||||||
|
|
||||||
|
@ -69,6 +69,8 @@ class AgChunk:
|
||||||
# some chunks can keep gathered all the time
|
# some chunks can keep gathered all the time
|
||||||
# so their computation patterns are the same as that of the parameters in DDP
|
# so their computation patterns are the same as that of the parameters in DDP
|
||||||
self.keep_gathered = keep_gathered
|
self.keep_gathered = keep_gathered
|
||||||
|
if self.keep_gathered:
|
||||||
|
pin_memory = False # since this chunk is gathered, it doesn't need to pin
|
||||||
|
|
||||||
# if pin_memory is True, we allocate a piece of CPU pin-memory
|
# if pin_memory is True, we allocate a piece of CPU pin-memory
|
||||||
# for it all the time
|
# for it all the time
|
||||||
|
@ -134,7 +136,7 @@ class AgChunk:
|
||||||
if new_utilized_size > self.chunk_size:
|
if new_utilized_size > self.chunk_size:
|
||||||
raise ChunkFullError
|
raise ChunkFullError
|
||||||
|
|
||||||
self.chunk_temp[self.utilized_size: new_utilized_size].copy_(tensor.flatten())
|
self.chunk_temp[self.utilized_size: new_utilized_size].copy_(tensor.data.flatten())
|
||||||
assert type(self.chunk_temp) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
|
assert type(self.chunk_temp) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
|
||||||
tensor.data = self.chunk_temp[self.utilized_size: new_utilized_size].view(tensor.shape)
|
tensor.data = self.chunk_temp[self.utilized_size: new_utilized_size].view(tensor.shape)
|
||||||
|
|
||||||
|
@ -145,7 +147,7 @@ class AgChunk:
|
||||||
self.tensors_state_monitor[tensor_state] += 1
|
self.tensors_state_monitor[tensor_state] += 1
|
||||||
self.utilized_size = new_utilized_size
|
self.utilized_size = new_utilized_size
|
||||||
|
|
||||||
def close_chunk(self, shard_dev: torch.device):
|
def close_chunk(self, shard_dev: Optional[torch.device] = None):
|
||||||
"""Close the chunk. Any tensor can't be appended to a closed chunk.
|
"""Close the chunk. Any tensor can't be appended to a closed chunk.
|
||||||
"""
|
"""
|
||||||
# sanity check
|
# sanity check
|
||||||
|
@ -159,6 +161,14 @@ class AgChunk:
|
||||||
|
|
||||||
self.__scatter()
|
self.__scatter()
|
||||||
|
|
||||||
|
if self.keep_gathered:
|
||||||
|
if shard_dev is None:
|
||||||
|
shard_dev = get_current_device()
|
||||||
|
else:
|
||||||
|
assert shard_dev.type == 'cuda'
|
||||||
|
elif shard_dev is None:
|
||||||
|
shard_dev = torch.device('cpu')
|
||||||
|
|
||||||
if self.pin_memory or shard_dev.type == 'cpu':
|
if self.pin_memory or shard_dev.type == 'cpu':
|
||||||
self.cpu_shard = torch.empty(self.shard_size,
|
self.cpu_shard = torch.empty(self.shard_size,
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
|
@ -364,3 +374,42 @@ class AgChunk:
|
||||||
for tensor_info in self.tensors_info.values():
|
for tensor_info in self.tensors_info.values():
|
||||||
if prev_state is None or tensor_info.state == prev_state:
|
if prev_state is None or tensor_info.state == prev_state:
|
||||||
self.__update_one_tensor_info(tensor_info, next_state)
|
self.__update_one_tensor_info(tensor_info, next_state)
|
||||||
|
|
||||||
|
def __repr__(self, detailed: bool = False):
|
||||||
|
output = [
|
||||||
|
"AgChunk Information:\n",
|
||||||
|
"\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(
|
||||||
|
self.chunk_size, self.dtype, self.pg_size),
|
||||||
|
"\t# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}\n".format(
|
||||||
|
self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size)
|
||||||
|
]
|
||||||
|
|
||||||
|
def print_tensor(tensor, prefix=''):
|
||||||
|
output.append("{}shape: {}, dtype: {}, device: {}\n".format(
|
||||||
|
prefix, tensor.shape, tensor.dtype, tensor.device))
|
||||||
|
|
||||||
|
if self.chunk_temp is not None:
|
||||||
|
output.append("\tchunk temp:\n")
|
||||||
|
print_tensor(tensor=self.chunk_temp, prefix='\t\t')
|
||||||
|
|
||||||
|
if self.chunk_total is not None and self.chunk_total.storage().size() > 0:
|
||||||
|
output.append("\tchunk total:\n")
|
||||||
|
print_tensor(tensor=self.chunk_total, prefix='\t\t')
|
||||||
|
|
||||||
|
if self.cuda_shard is not None:
|
||||||
|
output.append("\tcuda shard:\n")
|
||||||
|
print_tensor(tensor=self.cuda_shard, prefix='\t\t')
|
||||||
|
|
||||||
|
if self.cpu_shard is not None:
|
||||||
|
output.append("\tcpu shard:\n")
|
||||||
|
print_tensor(tensor=self.cpu_shard, prefix='\t\t')
|
||||||
|
|
||||||
|
memory_info = self.memory_usage
|
||||||
|
output.append("\tmemory usage: cuda {}, cpu {}\n".format(memory_info['cuda'], memory_info['cpu']))
|
||||||
|
|
||||||
|
if detailed:
|
||||||
|
output.append("\ttensor state monitor:\n")
|
||||||
|
for st in TensorState:
|
||||||
|
output.append("\t\t# of {}: {}\n".format(st, self.tensors_state_monitor[st]))
|
||||||
|
|
||||||
|
return ''.join(output)
|
||||||
|
|
|
@ -0,0 +1,81 @@
|
||||||
|
import torch
|
||||||
|
import colossalai
|
||||||
|
import pytest
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
from functools import partial
|
||||||
|
from colossalai.testing import rerun_if_address_is_in_use, parameterize
|
||||||
|
from colossalai.utils import free_port, get_current_device
|
||||||
|
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
||||||
|
from colossalai.tensor import ColoParameter
|
||||||
|
from colossalai.gemini.ag_chunk import AgChunk
|
||||||
|
|
||||||
|
|
||||||
|
def add_param(param_list, param_cp_list, *args, **kwargs):
|
||||||
|
param = ColoParameter(torch.empty(*args, **kwargs))
|
||||||
|
param_list.append(param)
|
||||||
|
param_cp_list.append(param.clone())
|
||||||
|
|
||||||
|
|
||||||
|
def check_euqal(param, param_cp):
|
||||||
|
if param.device != param_cp.device:
|
||||||
|
temp = param.data.to(param_cp.device)
|
||||||
|
else:
|
||||||
|
temp = param.data
|
||||||
|
return torch.equal(temp, param_cp.data)
|
||||||
|
|
||||||
|
|
||||||
|
@parameterize('init_device', [None, torch.device('cpu')])
|
||||||
|
@parameterize('keep_gathered', [True, False])
|
||||||
|
@parameterize('pin_memory', [True, False])
|
||||||
|
def exam_chunk_init(init_device, keep_gathered, pin_memory):
|
||||||
|
world_size = torch.distributed.get_world_size()
|
||||||
|
pg = ColoProcessGroup()
|
||||||
|
my_chunk = AgChunk(
|
||||||
|
chunk_size=1024,
|
||||||
|
process_group=pg,
|
||||||
|
dtype=torch.float32,
|
||||||
|
init_device=init_device,
|
||||||
|
keep_gathered=keep_gathered,
|
||||||
|
pin_memory=pin_memory
|
||||||
|
)
|
||||||
|
|
||||||
|
param_list = []
|
||||||
|
param_cp_list = []
|
||||||
|
|
||||||
|
add_param(param_list, param_cp_list, 8, 8, 8, device='cuda')
|
||||||
|
add_param(param_list, param_cp_list, 4, 4)
|
||||||
|
add_param(param_list, param_cp_list, 4, 8, 2, device='cuda')
|
||||||
|
add_param(param_list, param_cp_list, 1, 1, 5)
|
||||||
|
|
||||||
|
for param in param_list:
|
||||||
|
my_chunk.append_tensor(param)
|
||||||
|
assert my_chunk.utilized_size == 597
|
||||||
|
for param, param_cp in zip(param_list, param_cp_list):
|
||||||
|
check_euqal(param, param_cp)
|
||||||
|
my_chunk.close_chunk()
|
||||||
|
|
||||||
|
if keep_gathered is False:
|
||||||
|
assert my_chunk.cpu_shard.size(0) == 1024 // world_size
|
||||||
|
my_chunk.shard_move(get_current_device())
|
||||||
|
|
||||||
|
my_chunk.access_chunk()
|
||||||
|
|
||||||
|
for param, param_cp in zip(param_list, param_cp_list):
|
||||||
|
check_euqal(param, param_cp)
|
||||||
|
|
||||||
|
|
||||||
|
def run_dist(rank, world_size, port):
|
||||||
|
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
|
exam_chunk_init()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.dist
|
||||||
|
@pytest.mark.parametrize('world_size', [1, 2, 4])
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
|
def test_chunk_function(world_size):
|
||||||
|
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||||
|
mp.spawn(run_func, nprocs=world_size)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_chunk_function(2)
|
Loading…
Reference in New Issue