2021-10-28 16:21:23 +00:00
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
|
|
|
|
|
|
|
from colossalai.context.parallel_mode import ParallelMode
|
|
|
|
from colossalai.core import global_context as gpc
|
|
|
|
from colossalai.utils import get_current_device
|
2022-04-25 05:41:43 +00:00
|
|
|
from typing import Union, List, Tuple
|
2021-10-28 16:21:23 +00:00
|
|
|
|
2022-04-25 05:41:43 +00:00
|
|
|
# Anything usable as a tensor shape: a torch.Size, or a list/tuple of ints of any length.
# ``Tuple[int, ...]`` (variable length) is required here; ``Tuple[int]`` would mean a 1-tuple.
TensorShape = Union[torch.Size, List[int], Tuple[int, ...]]
|
2021-10-28 16:21:23 +00:00
|
|
|
|
2022-04-25 05:41:43 +00:00
|
|
|
|
2022-06-02 05:48:59 +00:00
|
|
|
def send_meta_helper(obj, next_rank, tensor_kwargs):
    """Send the shape of a single tensor to ``next_rank``.

    Two messages are sent, in order: the number of dimensions (a 0-dim tensor),
    then the dimensions themselves (a 1-D tensor). :func:`recv_meta_helper`
    reads them back in the same order.

    Args:
        obj (:class:`torch.Tensor`): Tensor whose shape is sent.
        next_rank (int): Destination rank.
        tensor_kwargs (dict): ``dtype``/``device`` used for the metadata tensors.
    """
    dims = obj.size()
    ndims_msg = torch.tensor(len(dims), **tensor_kwargs)
    shape_msg = torch.tensor(dims, **tensor_kwargs)
    # The receiver must learn the rank first so it can size its shape buffer.
    dist.send(ndims_msg, next_rank)
    dist.send(shape_msg, next_rank)
|
|
|
|
|
|
|
|
|
|
|
|
def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool:
    """Sends obj meta information before sending a specific obj.

    The recipient of a p2p transfer must know the shape of what it is about to
    receive, so the metadata is sent ahead of the payload. This function is
    paired with :func:`recv_obj_meta`.

    Wire format: first the number of tensors (a 0-dim long tensor), then, for
    each tensor, its ndims and its shape (see :func:`send_meta_helper`).

    NOTE: a single tensor and a one-element list both send a count of 1, and
    :func:`recv_obj_meta` returns a bare ``torch.Size`` for a count of 1.

    Args:
        obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): obj to be sent.
        need_meta (bool, optional): If False, meta information won't be sent.
        next_rank (int): The rank of the next member in pipeline parallel group.
            Defaults to the next global rank in the pipeline group.

    Returns:
        bool: False
    """
    if not need_meta:
        return False

    if next_rank is None:
        next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)

    tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
    # Treat a lone tensor as a one-element sequence so both cases share one path.
    tensors = [obj] if isinstance(obj, torch.Tensor) else obj
    dist.send(torch.tensor(len(tensors), **tensor_kwargs), next_rank)
    for tensor in tensors:
        send_meta_helper(tensor, next_rank, tensor_kwargs)

    return False
|
|
|
|
|
|
|
|
|
2022-06-02 05:48:59 +00:00
|
|
|
def recv_meta_helper(prev_rank, tensor_kwargs):
    """Receive the shape of a single tensor sent by :func:`send_meta_helper`.

    Args:
        prev_rank (int): Source rank.
        tensor_kwargs (dict): ``dtype``/``device`` used for the metadata tensors.

    Returns:
        :class:`torch.Tensor`: 1-D tensor holding the received dimensions.
    """
    # First message: number of dimensions, needed to size the shape buffer.
    ndims = torch.empty((), **tensor_kwargs)
    dist.recv(ndims, prev_rank)
    # Second message: the dimensions themselves.
    dims = torch.empty(ndims, **tensor_kwargs)
    dist.recv(dims, prev_rank)
    return dims
|
|
|
|
|
|
|
|
|
|
|
|
def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size:
    """Receives obj meta information before receiving a specific obj.

    The recipient of a p2p transfer must know the shape of what it is about to
    receive, so the metadata arrives ahead of the payload. This function is
    paired with :func:`send_obj_meta`. If ``obj_shape`` is already known (not
    None), it is returned unchanged and no communication takes place.

    Args:
        obj_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the obj to be received,
            or None to receive it from ``prev_rank``.
        prev_rank (int): The rank of the source of the obj. Defaults to the
            previous global rank in the pipeline group.

    Returns:
        Union[:class:`torch.Size`, List[:class:`torch.Size`]]: The shape of the obj to be received.
        A bare ``torch.Size`` is returned when the sender reported exactly one tensor.
    """
    if obj_shape is not None:
        return obj_shape

    if prev_rank is None:
        prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)

    tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
    num_objs = torch.empty((), **tensor_kwargs)
    dist.recv(num_objs, prev_rank)

    if num_objs.item() == 1:
        return torch.Size(recv_meta_helper(prev_rank, tensor_kwargs))
    # Shapes must be received in order, one helper call per tensor.
    return [torch.Size(recv_meta_helper(prev_rank, tensor_kwargs)) for _ in range(num_objs.item())]
|
2022-01-07 05:22:22 +00:00
|
|
|
|
|
|
|
|
2022-04-25 05:41:43 +00:00
|
|
|
def split_tensor_into_1d_equal_chunks(tensor: torch.Tensor, new_buffer=False) -> torch.Tensor:
    """Break a tensor into equal 1D chunks and return this rank's chunk.

    The flattened tensor is divided into ``world_size`` chunks of
    ``numel // world_size`` elements over the ``PARALLEL_1D`` group. The chunk
    at index ``local_rank`` is returned. If ``numel`` is not divisible by the
    world size, the trailing ``numel % world_size`` elements belong to no chunk.

    Args:
        tensor (:class:`torch.Tensor`): Tensor to be split before communication.
            Must be contiguous, because it is flattened with ``view(-1)``.
        new_buffer (bool, optional): Whether to use a new buffer to store sliced tensor.
            If False, the returned chunk is a view sharing storage with ``tensor``.

    Returns:
        :class:`torch.Tensor`: The split tensor
    """
    world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
    partition_size = torch.numel(tensor) // world_size
    start_index = partition_size * gpc.get_local_rank(ParallelMode.PARALLEL_1D)
    end_index = start_index + partition_size

    chunk = tensor.view(-1)[start_index:end_index]
    if not new_buffer:
        return chunk

    # Allocate on the same device helper used by the rest of this module
    # instead of hard-coding torch.cuda.current_device().
    data = torch.empty(partition_size, dtype=tensor.dtype, device=get_current_device(), requires_grad=False)
    data.copy_(chunk)
    return data
|
|
|
|
|
|
|
|
|
2022-04-25 05:41:43 +00:00
|
|
|
def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Gather 1D chunks from all ranks of the ``PARALLEL_1D`` group.

    This is the inverse of :func:`split_tensor_into_1d_equal_chunks`. Every rank
    contributes ``tensor``, and the result is the concatenation of all
    contributions in rank order, with length ``world_size * numel(tensor)``.
    Every rank is expected to contribute a chunk with the same number of
    elements.

    Args:
        tensor (:class:`torch.Tensor`): Tensor to be gathered after communication.

    Returns:
        :class:`torch.Tensor`: The gathered tensor.
    """
    world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
    numel = torch.numel(tensor)
    numel_gathered = world_size * numel
    # Same device helper as the rest of this module (previously hard-coded to CUDA).
    gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=get_current_device(), requires_grad=False)
    # Views into `gathered`, so all_gather writes straight into the output buffer.
    chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
    dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.PARALLEL_1D))
    return gathered
|