[NFC] polish comments for Chunk class (#2116)

pull/2120/head
Jiarui Fang 2022-12-12 15:39:31 +08:00 committed by GitHub
parent 09d69e1c25
commit e99edfcb51
6 changed files with 95 additions and 83 deletions

View File

@ -71,8 +71,9 @@ class Chunk:
chunk_size (int): the number of elements in the chunk
process_group (ColoProcessGroup): the process group of this chunk
dtype (torch.dtype): the data type of the chunk
init_device (torch.device): optional, the device where the tensor is initialized
init_device (torch.device): optional, the device where the chunk's temporary tensor is stored during chunk construction.
The default value is None, which is the current GPU
cpu_shard_init (bool): a flag indicating whether the local chunk shard resides on CPU.
keep_gathered (bool): optional, if True, this chunk is always gathered in CUDA memory
pin_memory (bool): optional, if True, this chunk always has a shard copied in pinned CPU memory
"""
@ -81,13 +82,12 @@ class Chunk:
self.chunk_size = chunk_size
self.utilized_size = 0
# Here, we use torch process group,
# since ColoProcessGroup might get deprecated soon
self.torch_pg = process_group.dp_process_group()
self.pg_size = dist.get_world_size(self.torch_pg)
self.pg_rank = dist.get_rank(self.torch_pg)
# the chunk size should be able to be divied by the size of GPU
# the chunk size should be divisible by the dp degree
if not keep_gathered:
assert chunk_size % self.pg_size == 0
self.shard_size = chunk_size // self.pg_size
@ -97,13 +97,21 @@ class Chunk:
self.dtype = dtype
device = init_device or get_current_device()
# chunk_temp is a global chunk, which only exists during building the chunks.
self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device) # keep all zero
self.chunk_total = None # we force chunk_total located in CUDA
self.cuda_shard = None # using two attributes for the better interpretation
self.cuda_global_chunk = None # we force cuda_global_chunk to reside in CUDA memory
# cuda local chunk, which is sharded on GPUs
self.cuda_shard = None
# cpu local chunk, which is sharded on CPUs
self.cpu_shard = None
# whether the chunk is gathered, which means the chunk is duplicated on each process
# and cuda_global_chunk should be used.
self.is_gathered = True
# configure the init deivce of the shard
# configure the init device of the shard
# no-offload default: fp16, fp32 -> CUDA
# offload default: fp16, fp32 -> CPU
self.shard_device = torch.device("cpu") if cpu_shard_init else get_current_device()
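
A hedged sketch of the flat layout these attributes imply. The shard_begin/shard_end arithmetic mirrors the attributes used later in this diff but is an assumption about how they are computed, not code from the commit.

chunk_size, pg_size, pg_rank = 1024, 4, 1      # illustrative values
shard_size  = chunk_size // pg_size            # 256 elements held per rank
shard_begin = shard_size * pg_rank             # rank 1 owns elements [256, 512)
shard_end   = shard_begin + shard_size
# when the chunk is scattered, this rank's cuda_shard (or cpu_shard) holds the
# slice cuda_global_chunk[shard_begin:shard_end]; when it is gathered,
# cuda_global_chunk holds the full 1024 elements on every rank.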
@ -111,17 +119,19 @@ class Chunk:
self.chunk_mem = self.chunk_size * self.chunk_temp.element_size()
self.shard_mem = self.chunk_mem // self.pg_size
# each tensor is associated with a TensorInfo to track meta info
# each tensor is associated with a TensorInfo to track its meta info
# (state, offset, end)
self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
# the total number of all tensors
# the total number of tensors in the chunk
self.num_tensors = 0
# monitor the states of all tensors
self.tensors_state_monitor: Dict[TensorState, int] = dict()
for state in TensorState:
self.tensors_state_monitor[state] = 0
# some chunks can keep gathered all the time
# so their computation patterns are the same as that of the parameters in DDP
# Record the number of tensors in different states
self.tensor_state_cnter: Dict[TensorState, int] = dict()
for state in TensorState:
self.tensor_state_cnter[state] = 0
# If a chunk is kept gathered,
# its tensors are treated the same as parameters in DDP during training.
self.keep_gathered = keep_gathered
if self.keep_gathered:
pin_memory = False # since this chunk is gathered, it doesn't need to pin
@ -182,7 +192,7 @@ class Chunk:
assert self.chunk_temp is None
if self.is_gathered:
return self.chunk_total
return self.cuda_global_chunk
elif self.cuda_shard is not None:
return self.cuda_shard
else:
@ -207,19 +217,19 @@ class Chunk:
if self.keep_gathered:
return False
else:
return self.tensors_state_monitor[TensorState.HOLD] + \
self.tensors_state_monitor[TensorState.HOLD_AFTER_BWD] == self.num_tensors
return self.tensor_state_cnter[TensorState.HOLD] + \
self.tensor_state_cnter[TensorState.HOLD_AFTER_BWD] == self.num_tensors
@property
def can_reduce(self):
return self.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == self.num_tensors
return self.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == self.num_tensors
@property
def has_inf_or_nan(self) -> bool:
"""Check if the chunk has inf or nan values on CUDA.
"""
if self.is_gathered:
valid_tensor = self.chunk_total[:self.utilized_size]
valid_tensor = self.cuda_global_chunk[:self.utilized_size]
else:
assert self.cuda_shard is not None # only check on CUDA
valid_tensor = self.cuda_shard[:self.valid_end]
@ -231,7 +241,7 @@ class Chunk:
"""
assert self.l2_norm is None, "you are calculating the l2 norm twice"
if self.is_gathered:
valid_tensor = self.chunk_total[:self.utilized_size]
valid_tensor = self.cuda_global_chunk[:self.utilized_size]
else:
assert self.cuda_shard is not None # calculate on CUDA
valid_tensor = self.cuda_shard[:self.valid_end]
@ -261,7 +271,7 @@ class Chunk:
self.num_tensors += 1
tensor_state = TensorState.HOLD
self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
self.tensors_state_monitor[tensor_state] += 1
self.tensor_state_cnter[tensor_state] += 1
self.utilized_size = new_utilized_size
def close_chunk(self):
@ -277,10 +287,10 @@ class Chunk:
self.valid_end = self.utilized_size - self.shard_begin
if self.chunk_temp.device.type == 'cpu':
self.chunk_total = self.chunk_temp.to(get_current_device())
self.cuda_global_chunk = self.chunk_temp.to(get_current_device())
self.__update_tensors_ptr()
else:
self.chunk_total = self.chunk_temp
self.cuda_global_chunk = self.chunk_temp
self.chunk_temp = None
self.__scatter()
@ -366,19 +376,19 @@ class Chunk:
if self.pg_size == 1:
# tricky code here
# just move chunk_total to cuda_shard
# just move cuda_global_chunk to cuda_shard
# the communication is not necessary
self.__scatter()
elif self.keep_gathered:
# we use all-reduce here
dist.all_reduce(self.chunk_total, group=self.torch_pg)
dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg)
else:
self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device())
input_list = list(torch.chunk(self.chunk_total, chunks=self.pg_size, dim=0))
input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
free_storage(self.chunk_total)
free_storage(self.cuda_global_chunk)
self.is_gathered = False
self.__update_tensors_state(TensorState.HOLD)
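
For readers unfamiliar with the collective used in the non-gathered branch above, a hedged standalone sketch of the reduce-scatter pattern; reduce_scatter_flat is a hypothetical helper, not part of this commit.

import torch
import torch.distributed as dist

def reduce_scatter_flat(flat: torch.Tensor, pg) -> torch.Tensor:
    # every rank contributes the whole flat buffer, split into world-size pieces;
    # rank i receives only piece i, element-wise summed over all ranks
    world = dist.get_world_size(pg)
    assert flat.numel() % world == 0
    shard = torch.empty(flat.numel() // world, dtype=flat.dtype, device=flat.device)
    pieces = list(torch.chunk(flat, chunks=world, dim=0))
    dist.reduce_scatter(shard, pieces, group=pg)
    return shard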
@ -413,8 +423,8 @@ class Chunk:
assert self.is_gathered
tensor_info = self.tensors_info[tensor]
self.chunk_total[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten())
tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape)
self.cuda_global_chunk[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten())
tensor.data = self.cuda_global_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape)
def get_valid_length(self) -> int:
"""Get the valid length of the chunk's payload.
@ -443,7 +453,7 @@ class Chunk:
friend_chunk = self.paired_chunk
if self.is_gathered is True:
assert friend_chunk.is_gathered is True
self.chunk_total.copy_(friend_chunk.chunk_total)
self.cuda_global_chunk.copy_(friend_chunk.cuda_global_chunk)
self.optim_sync_flag = True
elif friend_chunk.device_type == 'cuda' and self.device_type == 'cuda':
self.cuda_shard.copy_(friend_chunk.cuda_shard)
@ -465,8 +475,8 @@ class Chunk:
# sanity check
assert self.cuda_shard is not None
alloc_storage(self.chunk_total)
gather_list = list(torch.chunk(input=self.chunk_total, chunks=self.pg_size, dim=0))
alloc_storage(self.cuda_global_chunk)
gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0))
dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)
self.cuda_shard = None
@ -480,11 +490,11 @@ class Chunk:
# sanity check
assert self.cuda_shard is None
self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=self.chunk_total.device)
self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=self.cuda_global_chunk.device)
self.cuda_shard.copy_(self.chunk_total[self.shard_begin:self.shard_end])
self.cuda_shard.copy_(self.cuda_global_chunk[self.shard_begin:self.shard_end])
free_storage(self.chunk_total)
free_storage(self.cuda_global_chunk)
self.is_gathered = False
def __paired_shard_move(self):
@ -505,15 +515,15 @@ class Chunk:
def __update_tensors_ptr(self) -> None:
# sanity check
assert self.is_gathered
assert type(self.chunk_total) == torch.Tensor
assert type(self.cuda_global_chunk) == torch.Tensor
for tensor, tensor_info in self.tensors_info.items():
tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape)
tensor.data = self.cuda_global_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape)
def __update_one_tensor_info(self, tensor_info: TensorInfo, next_state: TensorState):
self.tensors_state_monitor[tensor_info.state] -= 1
self.tensor_state_cnter[tensor_info.state] -= 1
tensor_info.state = next_state
self.tensors_state_monitor[tensor_info.state] += 1
self.tensor_state_cnter[tensor_info.state] += 1
def __update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
for tensor_info in self.tensors_info.values():
@ -543,9 +553,9 @@ class Chunk:
output.append("\tchunk temp:\n")
print_tensor(tensor=self.chunk_temp, prefix='\t\t')
if self.chunk_total is not None and self.chunk_total.storage().size() > 0:
if self.cuda_global_chunk is not None and self.cuda_global_chunk.storage().size() > 0:
output.append("\tchunk total:\n")
print_tensor(tensor=self.chunk_total, prefix='\t\t')
print_tensor(tensor=self.cuda_global_chunk, prefix='\t\t')
if self.cuda_shard is not None:
output.append("\tcuda shard:\n")
@ -561,6 +571,6 @@ class Chunk:
if detailed:
output.append("\ttensor state monitor:\n")
for st in TensorState:
output.append("\t\t# of {}: {}\n".format(st, self.tensors_state_monitor[st]))
output.append("\t\t# of {}: {}\n".format(st, self.tensor_state_cnter[st]))
return ''.join(output)

View File

@ -299,7 +299,7 @@ class ZeroDDP(ColoDDP):
reduced = self.chunk_manager.reduce_chunk(chunk)
if reduced:
if chunk.is_gathered:
chunk.chunk_total.div_(chunk.pg_size)
chunk.cuda_global_chunk.div_(chunk.pg_size)
else:
chunk.cuda_shard.div_(chunk.pg_size)
# check overflow elements
@ -529,7 +529,7 @@ class ZeroDDP(ColoDDP):
load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice))
if chunk.is_gathered:
chunk.chunk_total.copy_(temp_chunk)
chunk.cuda_global_chunk.copy_(temp_chunk)
elif chunk.cuda_shard is not None:
chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end])
else:

View File

@ -1,20 +1,21 @@
import torch
import torch.distributed as dist
from colossalai.gemini.chunk import Chunk
from colossalai.utils import get_current_device
def get_temp_total_chunk_on_cuda(chunk: Chunk):
if chunk.is_gathered:
return chunk.chunk_total
if chunk.cuda_shard is not None:
shard_temp = chunk.cuda_shard
else:
shard_temp = chunk.cpu_shard.to(get_current_device())
total_temp = torch.zeros(chunk.chunk_size, dtype=chunk.dtype, device=get_current_device())
gather_list = list(torch.chunk(input=total_temp, chunks=chunk.pg_size, dim=0))
dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg)
return total_temp
import torch
import torch.distributed as dist
from colossalai.gemini.chunk import Chunk
from colossalai.utils import get_current_device
def get_temp_total_chunk_on_cuda(chunk: Chunk):
if chunk.is_gathered:
return chunk.cuda_global_chunk
if chunk.cuda_shard is not None:
shard_temp = chunk.cuda_shard
else:
shard_temp = chunk.cpu_shard.to(get_current_device())
total_temp = torch.zeros(chunk.chunk_size, dtype=chunk.dtype, device=get_current_device())
gather_list = list(torch.chunk(input=total_temp, chunks=chunk.pg_size, dim=0))
dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg)
return total_temp
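
A hedged usage sketch for the helper above; the chunk object is assumed to be already closed and registered with a chunk manager.

# materialize the chunk's full payload on the current CUDA device,
# e.g. to inspect a tensor stored inside it after reduction
full = get_temp_total_chunk_on_cuda(chunk)     # 1-D tensor of chunk.chunk_size elements
print(full[:8])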

View File

@ -9,10 +9,11 @@ from colossalai.tensor.tensor_spec import ColoTensorSpec
class ColoParamOpHook(ABC):
"""Hook which is triggered by each operation when operands contain ColoParameter.
"""
Hook which is triggered by each operation when operands contain ColoParameter.
To customize it, you must inherit this abstract class, and implement ``pre_forward``,
``post_forward``, ``pre_backward`` and ``post_backward``. These four methods take a list
of ColoParameter.
``post_forward``, ``pre_backward`` and ``post_backward``.
These four methods accept a list of ColoParameter as input arguments.
"""
@abstractmethod
@ -33,7 +34,8 @@ class ColoParamOpHook(ABC):
class ColoParamOpHookManager:
"""Manage your param op hooks. It only has static methods.
"""
Manage your param op hooks. It only has static methods.
The only static method you should call is ``use_hooks(*hooks)``.
"""
hooks: Tuple[ColoParamOpHook, ...] = tuple()
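
A hedged example (not part of this commit) of subclassing the hook and activating it through ``use_hooks``; the import path is an assumption, and CountingHook is a hypothetical hook used only for illustration.

from typing import List
import torch
from colossalai.tensor.param_op_hook import ColoParamOpHook, ColoParamOpHookManager  # assumed path

class CountingHook(ColoParamOpHook):
    """Counts how many ColoParameters the intercepted operations touch."""

    def __init__(self):
        super().__init__()
        self.seen = 0

    def pre_forward(self, params: List[torch.Tensor]) -> None:
        self.seen += len(params)

    def post_forward(self, params: List[torch.Tensor]) -> None:
        pass

    def pre_backward(self, params: List[torch.Tensor]) -> None:
        pass

    def post_backward(self, params: List[torch.Tensor]) -> None:
        pass

# with ColoParamOpHookManager.use_hooks(CountingHook()):
#     output = model(inputs)   # hooks fire on ops whose operands contain ColoParameter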

View File

@ -2,23 +2,22 @@ from typing import Optional
import torch
import torch.distributed as dist
from colossalai.gemini.memory_tracer import MemStatsCollector
from colossalai.gemini.ophooks import BaseOpHook
from colossalai.gemini.stateful_tensor import TensorState
from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.logging import get_dist_logger
from colossalai.registry import OPHOOKS
from colossalai.utils import get_current_device
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.gemini.ophooks import BaseOpHook
from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.gemini.memory_tracer import MemStatsCollector
from colossalai.gemini.stateful_tensor import TensorState
@OPHOOKS.register_module
class ZeroHook(BaseOpHook):
"""
A hook to process sharded param for ZeRO method.
Warning: this class has been deprecated after version 0.1.12
"""
def __init__(self,

View File

@ -69,7 +69,7 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory):
assert my_chunk.can_move
my_chunk.shard_move(get_current_device())
else:
assert my_chunk.chunk_total.size(0) == 1024
assert my_chunk.cuda_global_chunk.size(0) == 1024
assert my_chunk.device_type == 'cuda'
assert not my_chunk.can_move
@ -82,27 +82,27 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory):
for param, param_cp in zip(param_list, param_cp_list):
check_euqal(param, param_cp)
assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 4
assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 4
my_chunk.tensor_trans_state(param_list[0], TensorState.COMPUTE)
assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 3
assert my_chunk.tensors_state_monitor[TensorState.COMPUTE] == 1
assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 3
assert my_chunk.tensor_state_cnter[TensorState.COMPUTE] == 1
assert not my_chunk.can_release
for param in param_list:
my_chunk.tensor_trans_state(param, TensorState.COMPUTE)
my_chunk.tensor_trans_state(param, TensorState.READY_FOR_REDUCE)
assert my_chunk.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == 4
assert my_chunk.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == 4
assert my_chunk.can_reduce
my_chunk.reduce()
assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 4
assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 4
if keep_gathered is False:
assert my_chunk.cuda_shard.size(0) == 1024 // world_size
assert my_chunk.device_type == 'cuda'
assert my_chunk.can_move
else:
assert my_chunk.chunk_total.size(0) == 1024
assert my_chunk.cuda_global_chunk.size(0) == 1024
assert my_chunk.device_type == 'cuda'
assert not my_chunk.can_move