Making large AI models cheaper, faster and more accessible
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 

291 lines
10 KiB

from contextlib import contextmanager
from typing import List
import torch
import torch.distributed as dist
from torch import nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ProcessGroup, get_world_size
from colossalai.accelerator import get_accelerator
class SeqParallelUtils:
@staticmethod
def marked_as_sp_partial_derived_param(param):
"""
Mark a parameter as partially derived in sequence parallelism.
Args:
param: The parameter to mark as partially derived.
"""
setattr(param, "partial_derived", True)
@staticmethod
def is_sp_partial_derived_param(param):
"""
Check if a parameter is marked as partially derived in sequence parallelism.
Args:
param: The parameter to check.
Returns:
bool: True if the parameter is marked as partially derived, False otherwise.
"""
return getattr(param, "partial_derived", False)
@staticmethod
def allreduce_partial_data_grad(
process_group: ProcessGroup,
model: nn.Module = None,
grads: List[torch.Tensor] = None,
):
"""
Allreduce partial derived gradients across the specified process group.
This function performs gradient synchronization for parameters that are marked as partially derived in sequence parallelism.
Args:
process_group (ProcessGroup): The process group for gradient synchronization.
model (nn.Module): The model from which gradients will be synchronized.
grads (List[torch.Tensor]): The list of gradients to be synchronized.
only_sp_partial (bool): Whether handle all the parameters or only parameters marked as partial derived.
Raises:
AssertionError: If both `model` and `grads` are provided or neither is provided.
"""
# Ensure that exactly one of `model` and `grads` is provided for gradient synchronization.
assert (model is not None) ^ (grads is not None), "Exactly one of model and grads must be not None."
# Get the size of the process group, which determines whether synchronization is needed.
group_size = get_world_size(process_group) if process_group is not None else 1
if group_size == 1:
# If the process group size is 1, no synchronization is required.
return
if model is not None:
# If `model` is provided, extract partial derived gradients from the model's parameters.
grads = []
for p in model.parameters():
if p.grad is not None:
if SeqParallelUtils.is_sp_partial_derived_param(p):
grads.append(p.grad.data)
# Flatten and reduce the gradients using the specified process group.
if len(grads) == 0:
return
coalesced = _flatten_dense_tensors(grads)
dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=process_group)
# Unflatten the synchronized gradients and update the model's gradients.
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
else:
# If `grads` are provided explicitly, synchronize those gradients directly.
coalesced = _flatten_dense_tensors(grads)
dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=process_group)
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
class Randomizer:
"""
Randomizer enables the program to be executed under a different seed within the context.
Example:
```python
randomizer = Randomizer(seed=1024)
with randomizer.fork():
# do something here with seed 1024
do_something()
```
Args:
seed (int): The random seed to set.
enable_cpu (bool): fork the CPU RNG state as well.
with_index (bool): whether to use the index of the randomizer.
"""
_INDEX = 0
def __init__(self, seed: int):
self.seed = seed
# Handle device rng state
# 1. get the current rng state
# 2. set the seed and store the rng state
# 3. recover the original rng state
device_original_rng_state = get_accelerator().get_rng_state()
get_accelerator().manual_seed(seed)
self.device_rng_state = get_accelerator().get_rng_state()
get_accelerator().set_rng_state(device_original_rng_state)
# to the same for cpu rng state
cpu_original_rng_state = torch.get_rng_state()
torch.manual_seed(seed)
self.cpu_rng_state = torch.get_rng_state()
torch.set_rng_state(cpu_original_rng_state)
def _set_device_rng_state(self, rng_state):
get_accelerator().set_rng_state(rng_state)
def _get_device_rng_state(self):
current_state = get_accelerator().get_rng_state()
return current_state
def _set_cpu_rng_state(self, rng_state):
torch.set_rng_state(rng_state)
def _get_cpu_rng_state(self):
current_state = torch.get_rng_state()
return current_state
@contextmanager
def fork_rng(self, enable_cpu: bool = False):
"""
This is a context manager to change the dropout state and recover the original state.
Usage:
::
>>> with _seed_manager.dropout_mode():
>>> input = super().forward(input)
"""
try:
current_device_rng_state = self._get_device_rng_state()
self._set_device_rng_state(self.device_rng_state)
if enable_cpu:
current_cpu_rng_state = self._get_cpu_rng_state()
self._set_cpu_rng_state(self.cpu_rng_state)
yield
finally:
self.device_rng_state = self._get_device_rng_state()
self._set_device_rng_state(current_device_rng_state)
if enable_cpu:
self.cpu_rng_state = self._get_cpu_rng_state()
self._set_cpu_rng_state(current_cpu_rng_state)
@staticmethod
def index():
"""
Return the index of the randomizer. The index is useful when the user wants
to introduce some randomness in the program.
Note:
The index will increment by one each time this method is called.
Example:
```python
# assume we need a randomizer to init the weight of different layers
# we can use the index of the randomizer to do so that
# each layer has its own randomizer with a different seed
base_seed = torch.random.initial_seed()
seed = base_seed + Randomizer.index()
randomizer = Randomizer(seed)
with randomizer.fork():
init_weights()
```
"""
idx = Randomizer._INDEX
return idx
@staticmethod
def increment_index():
"""
Increment the index of the randomizer by one.
"""
Randomizer._INDEX += 1
@staticmethod
def reset_index():
"""
Reset the index to zero.
"""
Randomizer._INDEX = 0
@staticmethod
def is_randomizer_index_synchronized(process_group: ProcessGroup = None):
"""
Return whether the randomizer index is synchronized across processes.
"""
index = Randomizer.index()
if dist.is_initialized():
# convert the index to tensor
index_tensor = torch.tensor(index, dtype=torch.int32, device=get_accelerator().get_current_device())
# all gather the index
gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
dist.all_gather(gathered_index, index_tensor, process_group)
# make sure all the gathered index are the same
for i in range(1, dist.get_world_size(process_group)):
if gathered_index[i] != gathered_index[0]:
return False
return True
@staticmethod
def synchronize_index(process_group: ProcessGroup = None):
"""
All gather the index and pick the largest value.
"""
index = Randomizer.index()
if dist.is_initialized():
# convert the index to tensor
index_tensor = torch.tensor(index, dtype=torch.int32, device=get_accelerator().get_current_device())
# all gather the index
gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
dist.all_gather(gathered_index, index_tensor, process_group)
# pick the largest index
for i in range(1, dist.get_world_size(process_group)):
if gathered_index[i] > index_tensor:
index_tensor = gathered_index[i]
# set the index
Randomizer._INDEX = index_tensor.item()
def create_randomizer_with_offset(
seed: int, process_group: ProcessGroup = None, offset_by_rank: bool = True, offset_by_index: bool = True
):
"""
Create a randomizer with an offset. The offset is equal to the rank of the process and the index of the randomizer.
Args:
seed (int): The base random seed to set.
process_group (ProcessGroup): the process group to get the rank from.
offset_by_rank (bool): whether to offset by the rank of the process, i.e., the rank of the process will be added to the seed. Default: True.
offset_by_index (bool): whether to offset by the index of the randomizer, i.e., the index of the randomizer will be added to the seed. Default: True.
Returns:
Randomizer: the randomizer with offset.
"""
base_seed = seed
if offset_by_rank and dist.is_initialized():
rank = dist.get_rank(process_group)
base_seed += rank
if offset_by_index:
# check if the randomizer index is synchronized
is_synchronized = Randomizer.is_randomizer_index_synchronized(process_group)
assert is_synchronized, (
"We detect that the randomizer index is not synchronized across processes."
"This is not allowed when we want to create a randomizer with offset by index."
"Please call Randomizer.synchronize_index() first."
)
base_seed += Randomizer.index()
Randomizer.increment_index()
return Randomizer(seed=base_seed)