# ColossalAI/colossalai/elixir/chunk/fetcher.py

from typing import List, Optional

import torch

from .core import Chunk, ChunkGroup, TensorBlock, TensorState
from .scheduler import ChunkScheduler


class ChunkFetcher(object):
    """ChunkFetcher is responsible for fetching and reducing chunks during training.
    Any operation on chunks should go through the ChunkFetcher.

    Args:
        scheduler: A ChunkScheduler used to schedule evictable chunks.
        group: A ChunkGroup that manages the chunks.
        overlap: Whether to overlap communication with computation.
        reduce_always_fp32: Whether to reduce gradients in FP32.
    """

    def __init__(self,
                 scheduler: ChunkScheduler,
                 group: ChunkGroup,
                 overlap: bool = False,
                 reduce_always_fp32: bool = False) -> None:

        self.scheduler: ChunkScheduler = scheduler
        self.group: ChunkGroup = group
        self.reduce_always_fp32 = reduce_always_fp32
        self.current_step = -1

        # overlapped communication
        self.overlap_flag = overlap
        self.main_stream = torch.cuda.current_stream()

        # asynchronous prefetch state
        self.predict_next_chunk: Optional[Chunk] = None
        self.is_fetching: bool = False
        self.prefetch_stream = torch.cuda.Stream()

        # asynchronous reduce state
        self.reduced_chunk: Optional[Chunk] = None
        self.reduced_block: Optional[TensorBlock] = None
        self.reduce_stream = torch.cuda.Stream()

    def reset(self):
        """Reset the fetcher to its initial state.
        Users should call this function before training.
        """
        self.scheduler.reset()
        self.current_step = -1

    def clear(self):
        """Clear the fetcher.
        Users should call this function after training.
        """
        if self.overlap_flag:
            torch.cuda.synchronize()

            self.predict_next_chunk = None
            self.is_fetching = False

            if self.reduced_chunk is not None:
                self.reduce_call_back()
                self.reduced_chunk = None
                self.reduced_block = None

        self.scheduler.clear()

    def trans_to_compute(self, tensors: List[torch.Tensor]):
        """Transform tensors to the COMPUTE state.
        This function should be called before the compute operators.
        """
        # update tensor states
        for t in tensors:
            self.group.tensor_trans_state(t, TensorState.COMPUTE)
        # chunk operations
        chunks = self.group.tensors_to_chunks(tensors)
        for chunk in chunks:
            self.scheduler.remove(chunk)
        return chunks

    def trans_to_hold(self, tensors: List[torch.Tensor], phase: str):
        """Transform tensors to the HOLD state.
        This function should be called after the compute operators.
        """
        assert phase in ('f', 'b')
        next_state = TensorState.HOLD if phase == 'f' else TensorState.HOLD_AFTER_BWD
        # update tensor states
        for t in tensors:
            self.group.tensor_trans_state(t, next_state)
        # chunk operations
        chunks = self.group.tensors_to_chunks(tensors)
        for chunk in chunks:
            if chunk.scatter_check:
                self.scheduler.add(chunk)

    def get_one_chunk(self, tensor: torch.Tensor) -> Chunk:
        """Get the chunk that owns the given tensor."""
        return self.group.ten_to_chunk.get(tensor)

    def get_chunks(self, tensors: List[torch.Tensor]) -> List[Chunk]:
        """Get the chunks of the given tensors."""
        return self.group.tensors_to_chunks(tensors)

    def is_in_fused(self, tensor: torch.Tensor):
        """Check whether the given tensor is in a fused chunk."""
        chunk = self.get_one_chunk(tensor)
        return chunk.rcache_fused

    def filter_chunks(self, chunks: List[Chunk]):
        """Filter out chunks that are already accessed, i.e. already gathered on the rCache."""
        return list(filter(lambda c: not self.group.is_accessed(c), chunks))

    def fetch_chunks(self, chunks: List[Chunk]):
        """Fetch the chunks needed by this compute operator.
        The chunks should be in the COMPUTE state first.
        """
        # make step + 1
        self.step()

        predict_hit = False
        # try to match the prefetched chunk
        if self.predict_next_chunk is not None and self.predict_next_chunk in chunks:
            if self.is_fetching:
                # prefetch hit, wait for the async prefetch
                self.main_stream.wait_stream(self.prefetch_stream)
            # clear prefetch information
            self.predict_next_chunk = None
            self.is_fetching = False
            predict_hit = True

        # filter accessed chunks
        scattered = self.filter_chunks(chunks)
        # sanity check: upload should wait for prefetch
        if self.predict_next_chunk is not None:
            assert len(scattered) == 0
        # all chunks are on the rCache
        if len(scattered) == 0:
            # prefetch if there was a hit above
            if predict_hit:
                self.prefetch(chunks)
            return

        for chunk in scattered:
            # if the rCache is not enough, just release a chunk
            if not self.group.rcache_enough_check(chunk):
                maybe_chunk = self.scheduler.top()
                # print(f'Evicting {chunk.chunk_id} -> {maybe_chunk.chunk_id}')
                if maybe_chunk is None:
                    raise RuntimeError('R cache is not enough. Try to allocate more.')
                self.scheduler.remove(maybe_chunk)
                self.group.release_chunk(maybe_chunk)
            # print('Accessing', chunk.chunk_id)
            self.group.access_chunk(chunk)

        if self.overlap_flag:
            assert self.predict_next_chunk is None
            self.prefetch(chunks)

    def reduce_call_back(self):
        """Callback for an overlapped reduce: update the chunk's reduce info and recycle its rCache block."""
        self.reduced_chunk.update_extra_reduce_info(self.reduced_block)
        if self.reduced_block is not None:
            self.group.rcache.free_public_block(self.reduced_block)

    def reduce_chunk(self, chunk: Chunk) -> bool:
        """Reduce and scatter the given gradient chunk."""
        if not chunk.reduce_check:
            return False
        self.scheduler.remove(chunk)

        if not self.overlap_flag:
            # reduce the chunk synchronously if communication is not overlapped
            self.group.reduce_chunk(chunk, always_fp32=self.reduce_always_fp32, sync=True)
        else:
            # the reduce stream should wait for the main stream's computation
            self.reduce_stream.wait_stream(self.main_stream)
            # the main stream should wait for the reduce stream
            # if there is a block to recycle
            if self.reduced_chunk is not None:
                self.main_stream.wait_stream(self.reduce_stream)
                self.reduce_call_back()

            with torch.cuda.stream(self.reduce_stream):
                self.reduced_chunk = chunk
                self.reduced_block = self.group.reduce_chunk(chunk,
                                                             always_fp32=self.reduce_always_fp32,
                                                             sync=False)

        return True

    def prefetch(self, chunks: List[Chunk]):
        """Prefetch the next used chunk asynchronously."""
        next_chunk = self.scheduler.get_next_chunk(chunks)
        self.predict_next_chunk = next_chunk
        # return if there is no next chunk or it is already on the rCache
        if next_chunk is None or self.group.is_accessed(next_chunk):
            return

        evict_chunk = None
        if not self.group.rcache_enough_check(next_chunk):
            maybe_chunk = self.scheduler.top()
            # if there is no chunk that can be evicted, just return
            if maybe_chunk is None:
                return
            # otherwise, release this chunk
            self.scheduler.remove(maybe_chunk)
            evict_chunk = maybe_chunk

        with torch.cuda.stream(self.prefetch_stream):
            # the prefetch stream should wait for the main stream
            self.prefetch_stream.wait_stream(self.main_stream)
            self.is_fetching = True

            if evict_chunk is not None:
                self.group.release_chunk(evict_chunk)
            self.group.access_chunk(next_chunk)

    def step(self):
        """Update the scheduler."""
        self.scheduler.step()
        self.current_step += 1
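

# A minimal sketch of how a pre-op / post-op hook pair might drive the fetcher
# around a single compute operator. The hook registration and the gathering of
# `params` are handled elsewhere in Elixir; the two helpers below are assumed
# names for illustration, not part of the public API.
def _sketch_pre_op(fetcher: ChunkFetcher, params: List[torch.Tensor]) -> None:
    # mark the parameters as COMPUTE, pinning their chunks out of the scheduler
    chunks = fetcher.trans_to_compute(params)
    # make sure every needed chunk is gathered on the rCache (may also prefetch)
    fetcher.fetch_chunks(chunks)


def _sketch_post_op(fetcher: ChunkFetcher, params: List[torch.Tensor], backward: bool) -> None:
    # hand the parameters back as HOLD (forward) or HOLD_AFTER_BWD (backward),
    # so their chunks become evictable again
    fetcher.trans_to_hold(params, phase='b' if backward else 'f')
    if backward:
        # reduce-scatter any chunk whose gradients are all ready
        # (reduce_chunk is a no-op for chunks that fail reduce_check)
        for chunk in fetcher.get_chunks(params):
            fetcher.reduce_chunk(chunk)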