2021-10-28 16:21:23 +00:00
|
|
|
#!/usr/bin/env python
|
|
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
2021-12-27 07:04:32 +00:00
|
|
|
from torch.distributed import ReduceOp
|
2021-10-28 16:21:23 +00:00
|
|
|
from torch import Tensor
|
|
|
|
|
|
|
|
from colossalai.context import ParallelMode
|
|
|
|
from colossalai.core import global_context as gpc
|
|
|
|
from colossalai.utils import get_current_device
|
|
|
|
|
|
|
|
|
2021-12-27 07:04:32 +00:00
|
|
|
def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor:
|
2021-10-28 16:21:23 +00:00
|
|
|
"""Gathers all tensors from the parallel group and concatenates them in a
|
|
|
|
specific dimension.
|
|
|
|
|
|
|
|
:param tensor: Tensor to be gathered
|
|
|
|
:param dim: The dimension concatenating in
|
|
|
|
:param parallel_mode: Parallel group mode used in this communication
|
2022-01-21 02:44:30 +00:00
|
|
|
:param async_op: Whether operations are asynchronous
|
|
|
|
|
2021-12-13 14:07:01 +00:00
|
|
|
:type tensor: :class:`torch.Tensor`
|
2021-10-28 16:21:23 +00:00
|
|
|
:type dim: int
|
2021-12-13 14:07:01 +00:00
|
|
|
:type parallel_mode: :class:`colossalai.context.ParallelMode`
|
2022-01-21 02:44:30 +00:00
|
|
|
:type async_op: bool, optional
|
|
|
|
|
2021-10-28 16:21:23 +00:00
|
|
|
:return: The tensor generated by all-gather
|
2021-12-13 14:07:01 +00:00
|
|
|
:rtype: :class:`torch.Tensor`
|
2021-10-28 16:21:23 +00:00
|
|
|
"""
|
|
|
|
depth = gpc.get_world_size(parallel_mode)
|
2021-12-27 07:04:32 +00:00
|
|
|
if depth == 1:
|
2022-02-14 03:15:02 +00:00
|
|
|
out = tensor
|
2021-12-27 07:04:32 +00:00
|
|
|
work = None
|
|
|
|
else:
|
|
|
|
shape = list(tensor.shape)
|
|
|
|
shape[0], shape[dim] = shape[dim], shape[0]
|
|
|
|
shape[0] *= depth
|
|
|
|
out = torch.empty(shape, dtype=tensor.dtype, device=get_current_device())
|
|
|
|
temp = list(torch.chunk(out, depth, dim=0))
|
|
|
|
work = dist.all_gather(tensor_list=temp,
|
|
|
|
tensor=tensor.transpose(0, dim).contiguous(),
|
|
|
|
group=gpc.get_group(parallel_mode),
|
|
|
|
async_op=async_op)
|
|
|
|
out = torch.transpose(out, 0, dim)
|
Develop/experiments (#59)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000
* Integrate 1d tensor parallel in Colossal-AI (#39)
* fixed 1D and 2D convergence (#38)
* optimized 2D operations
* fixed 1D ViT convergence problem
* Feature/ddp (#49)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* support torch ddp
* fix loss accumulation
* add log for ddp
* change seed
* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* Feature/pipeline (#40)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* optimize communication of pipeline parallel
* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)
* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset
* update api for better usability (#58)
update api for better usability
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
2021-12-09 07:08:29 +00:00
|
|
|
if async_op:
|
2021-12-27 07:04:32 +00:00
|
|
|
return out, work
|
Develop/experiments (#59)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000
* Integrate 1d tensor parallel in Colossal-AI (#39)
* fixed 1D and 2D convergence (#38)
* optimized 2D operations
* fixed 1D ViT convergence problem
* Feature/ddp (#49)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* support torch ddp
* fix loss accumulation
* add log for ddp
* change seed
* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* Feature/pipeline (#40)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* optimize communication of pipeline parallel
* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)
* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset
* update api for better usability (#58)
update api for better usability
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
2021-12-09 07:08:29 +00:00
|
|
|
else:
|
|
|
|
return out
|
2021-10-28 16:21:23 +00:00
|
|
|
|
|
|
|
|
2021-12-27 07:04:32 +00:00
|
|
|
def reduce_scatter(tensor: Tensor,
|
|
|
|
dim: int,
|
|
|
|
parallel_mode: ParallelMode,
|
|
|
|
op: ReduceOp = ReduceOp.SUM,
|
|
|
|
async_op: bool = False) -> Tensor:
|
2021-10-28 16:21:23 +00:00
|
|
|
"""Reduces all tensors then scatters it in a specific dimension to all
|
|
|
|
members in the parallel group.
|
|
|
|
|
|
|
|
:param tensor: Tensor to be reduced and scattered
|
|
|
|
:param dim: The dimension scattering in
|
|
|
|
:param parallel_mode: Parallel group mode used in this communication
|
2022-01-21 02:44:30 +00:00
|
|
|
:param op: The type of reduce operation
|
|
|
|
:param async_op: Whether operations are asynchronous
|
|
|
|
|
2021-12-13 14:07:01 +00:00
|
|
|
:type tensor: :class:`torch.Tensor`
|
2021-10-28 16:21:23 +00:00
|
|
|
:type dim: int
|
2021-12-13 14:07:01 +00:00
|
|
|
:type parallel_mode: :class:`colossalai.context.ParallelMode`
|
2022-01-21 02:44:30 +00:00
|
|
|
:type op: ReduceOp, optional
|
|
|
|
:type async_op: bool, optional
|
|
|
|
|
2021-10-28 16:21:23 +00:00
|
|
|
:return: The tensor generated by reduce-scatter
|
2021-12-13 14:07:01 +00:00
|
|
|
:rtype: :class:`Tensor`
|
2021-10-28 16:21:23 +00:00
|
|
|
"""
|
|
|
|
depth = gpc.get_world_size(parallel_mode)
|
2021-12-27 07:04:32 +00:00
|
|
|
if depth == 1:
|
|
|
|
out = tensor
|
|
|
|
work = None
|
|
|
|
else:
|
|
|
|
temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim)))
|
|
|
|
out = torch.empty(temp[0].shape, dtype=tensor.dtype, device=get_current_device())
|
|
|
|
work = dist.reduce_scatter(output=out,
|
|
|
|
input_list=temp,
|
|
|
|
op=op,
|
|
|
|
group=gpc.get_group(parallel_mode),
|
|
|
|
async_op=async_op)
|
Develop/experiments (#59)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000
* Integrate 1d tensor parallel in Colossal-AI (#39)
* fixed 1D and 2D convergence (#38)
* optimized 2D operations
* fixed 1D ViT convergence problem
* Feature/ddp (#49)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* support torch ddp
* fix loss accumulation
* add log for ddp
* change seed
* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* Feature/pipeline (#40)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* optimize communication of pipeline parallel
* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)
* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset
* update api for better usability (#58)
update api for better usability
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
2021-12-09 07:08:29 +00:00
|
|
|
if async_op:
|
2021-12-27 07:04:32 +00:00
|
|
|
return out, work
|
Develop/experiments (#59)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000
* Integrate 1d tensor parallel in Colossal-AI (#39)
* fixed 1D and 2D convergence (#38)
* optimized 2D operations
* fixed 1D ViT convergence problem
* Feature/ddp (#49)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* support torch ddp
* fix loss accumulation
* add log for ddp
* change seed
* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* Feature/pipeline (#40)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* optimize communication of pipeline parallel
* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)
* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset
* update api for better usability (#58)
update api for better usability
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
2021-12-09 07:08:29 +00:00
|
|
|
else:
|
|
|
|
return out
|
2021-10-28 16:21:23 +00:00
|
|
|
|
|
|
|
|
Develop/experiments (#59)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000
* Integrate 1d tensor parallel in Colossal-AI (#39)
* fixed 1D and 2D convergence (#38)
* optimized 2D operations
* fixed 1D ViT convergence problem
* Feature/ddp (#49)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* support torch ddp
* fix loss accumulation
* add log for ddp
* change seed
* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* Feature/pipeline (#40)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* optimize communication of pipeline parallel
* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)
* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset
* update api for better usability (#58)
update api for better usability
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
2021-12-09 07:08:29 +00:00
|
|
|
def all_reduce(tensor: Tensor,
|
|
|
|
parallel_mode: ParallelMode,
|
2021-12-27 07:04:32 +00:00
|
|
|
op: ReduceOp = ReduceOp.SUM,
|
|
|
|
async_op: bool = False) -> Tensor:
|
|
|
|
depth = gpc.get_world_size(parallel_mode)
|
|
|
|
if depth == 1:
|
2022-02-14 03:15:02 +00:00
|
|
|
out = tensor
|
2021-12-27 07:04:32 +00:00
|
|
|
work = None
|
|
|
|
else:
|
2022-02-14 03:15:02 +00:00
|
|
|
out = tensor.contiguous()
|
|
|
|
work = dist.all_reduce(out, op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
|
Develop/experiments (#59)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000
* Integrate 1d tensor parallel in Colossal-AI (#39)
* fixed 1D and 2D convergence (#38)
* optimized 2D operations
* fixed 1D ViT convergence problem
* Feature/ddp (#49)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* support torch ddp
* fix loss accumulation
* add log for ddp
* change seed
* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* Feature/pipeline (#40)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* optimize communication of pipeline parallel
* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)
* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset
* update api for better usability (#58)
update api for better usability
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
2021-12-09 07:08:29 +00:00
|
|
|
if async_op:
|
2022-02-14 03:15:02 +00:00
|
|
|
return out, work
|
Develop/experiments (#59)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000
* Integrate 1d tensor parallel in Colossal-AI (#39)
* fixed 1D and 2D convergence (#38)
* optimized 2D operations
* fixed 1D ViT convergence problem
* Feature/ddp (#49)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* support torch ddp
* fix loss accumulation
* add log for ddp
* change seed
* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* Feature/pipeline (#40)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* optimize communication of pipeline parallel
* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)
* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset
* update api for better usability (#58)
update api for better usability
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
2021-12-09 07:08:29 +00:00
|
|
|
else:
|
2022-02-14 03:15:02 +00:00
|
|
|
return out
|
Develop/experiments (#59)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000
* Integrate 1d tensor parallel in Colossal-AI (#39)
* fixed 1D and 2D convergence (#38)
* optimized 2D operations
* fixed 1D ViT convergence problem
* Feature/ddp (#49)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* support torch ddp
* fix loss accumulation
* add log for ddp
* change seed
* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* Feature/pipeline (#40)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* optimize communication of pipeline parallel
* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)
* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset
* update api for better usability (#58)
update api for better usability
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
2021-12-09 07:08:29 +00:00
|
|
|
|
|
|
|
|
2021-12-27 07:04:32 +00:00
|
|
|
def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: bool = False):
|
|
|
|
depth = gpc.get_world_size(parallel_mode)
|
|
|
|
if depth == 1:
|
2022-02-14 03:15:02 +00:00
|
|
|
out = tensor
|
2021-12-27 07:04:32 +00:00
|
|
|
work = None
|
|
|
|
else:
|
2022-02-14 03:15:02 +00:00
|
|
|
out = tensor.contiguous()
|
|
|
|
work = dist.broadcast(out, src=src, group=gpc.get_group(parallel_mode), async_op=async_op)
|
2021-12-27 07:04:32 +00:00
|
|
|
if async_op:
|
2022-02-14 03:15:02 +00:00
|
|
|
return out, work
|
2021-12-27 07:04:32 +00:00
|
|
|
else:
|
2022-02-14 03:15:02 +00:00
|
|
|
return out
|
2021-12-27 07:04:32 +00:00
|
|
|
|
|
|
|
|
|
|
|
def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False):
|
|
|
|
depth = gpc.get_world_size(parallel_mode)
|
|
|
|
if depth == 1:
|
2022-02-14 03:15:02 +00:00
|
|
|
out = tensor
|
2021-12-27 07:04:32 +00:00
|
|
|
work = None
|
|
|
|
else:
|
2022-02-14 03:15:02 +00:00
|
|
|
out = tensor.contiguous()
|
|
|
|
work = dist.reduce(out, dst=dst, op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
|
2021-12-27 07:04:32 +00:00
|
|
|
if async_op:
|
2022-02-14 03:15:02 +00:00
|
|
|
return out, work
|
2021-12-27 07:04:32 +00:00
|
|
|
else:
|
2022-02-14 03:15:02 +00:00
|
|
|
return out
|