2021-10-28 16:21:23 +00:00
|
|
|
#!/usr/bin/env python
|
|
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
|
2021-12-27 07:04:32 +00:00
|
|
|
from typing import Optional, Tuple
|
2021-10-28 16:21:23 +00:00
|
|
|
|
|
|
|
import torch
|
2022-02-14 03:15:02 +00:00
|
|
|
from colossalai.communication import (all_gather, all_reduce, broadcast, reduce, reduce_scatter)
|
2021-10-28 16:21:23 +00:00
|
|
|
from colossalai.context.parallel_mode import ParallelMode
|
|
|
|
from colossalai.core import global_context as gpc
|
|
|
|
from torch import Tensor
|
Develop/experiments (#59)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000
* Integrate 1d tensor parallel in Colossal-AI (#39)
* fixed 1D and 2D convergence (#38)
* optimized 2D operations
* fixed 1D ViT convergence problem
* Feature/ddp (#49)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* support torch ddp
* fix loss accumulation
* add log for ddp
* change seed
* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* Feature/pipeline (#40)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* optimize communication of pipeline parallel
* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)
* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset
* update api for better usability (#58)
update api for better usability
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
2021-12-09 07:08:29 +00:00
|
|
|
from torch.cuda.amp import custom_bwd, custom_fwd
|
2022-02-14 03:15:02 +00:00
|
|
|
from ._utils import get_parallel_mode_from_env
|
|
|
|
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
|
Develop/experiments (#59)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000
* Integrate 1d tensor parallel in Colossal-AI (#39)
* fixed 1D and 2D convergence (#38)
* optimized 2D operations
* fixed 1D ViT convergence problem
* Feature/ddp (#49)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* support torch ddp
* fix loss accumulation
* add log for ddp
* change seed
* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* Feature/pipeline (#40)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* optimize communication of pipeline parallel
* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)
* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset
* update api for better usability (#58)
update api for better usability
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
2021-12-09 07:08:29 +00:00
|
|
|
|
2022-01-10 10:05:58 +00:00
|
|
|
|
2022-02-14 03:15:02 +00:00
|
|
|
class _Linear3D(torch.autograd.Function):
|
|
|
|
|
Develop/experiments (#59)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000
* Integrate 1d tensor parallel in Colossal-AI (#39)
* fixed 1D and 2D convergence (#38)
* optimized 2D operations
* fixed 1D ViT convergence problem
* Feature/ddp (#49)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* support torch ddp
* fix loss accumulation
* add log for ddp
* change seed
* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* Feature/pipeline (#40)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* optimize communication of pipeline parallel
* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)
* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset
* update api for better usability (#58)
update api for better usability
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
2021-12-09 07:08:29 +00:00
|
|
|
@staticmethod
|
|
|
|
@custom_fwd(cast_inputs=torch.float16)
|
2021-12-27 07:04:32 +00:00
|
|
|
def forward(ctx,
|
Develop/experiments (#59)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000
* Integrate 1d tensor parallel in Colossal-AI (#39)
* fixed 1D and 2D convergence (#38)
* optimized 2D operations
* fixed 1D ViT convergence problem
* Feature/ddp (#49)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* support torch ddp
* fix loss accumulation
* add log for ddp
* change seed
* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* Feature/pipeline (#40)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* optimize communication of pipeline parallel
* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)
* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset
* update api for better usability (#58)
update api for better usability
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
2021-12-09 07:08:29 +00:00
|
|
|
input_: Tensor,
|
|
|
|
weight: Tensor,
|
|
|
|
bias: Optional[Tensor],
|
|
|
|
input_parallel_mode: ParallelMode,
|
|
|
|
weight_parallel_mode: ParallelMode,
|
|
|
|
output_parallel_mode: ParallelMode,
|
|
|
|
input_dim: int = 0,
|
|
|
|
weight_dim: int = -1,
|
|
|
|
output_dim: int = 0) -> Tensor:
|
|
|
|
ctx.use_bias = bias is not None
|
|
|
|
|
|
|
|
input_ = all_gather(input_, input_dim, input_parallel_mode)
|
2022-02-17 14:03:39 +00:00
|
|
|
weight = all_gather(weight, weight_dim, weight_parallel_mode)
|
Develop/experiments (#59)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000
* Integrate 1d tensor parallel in Colossal-AI (#39)
* fixed 1D and 2D convergence (#38)
* optimized 2D operations
* fixed 1D ViT convergence problem
* Feature/ddp (#49)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* support torch ddp
* fix loss accumulation
* add log for ddp
* change seed
* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* Feature/pipeline (#40)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* optimize communication of pipeline parallel
* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)
* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset
* update api for better usability (#58)
update api for better usability
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
2021-12-09 07:08:29 +00:00
|
|
|
ctx.save_for_backward(input_, weight)
|
|
|
|
|
|
|
|
output = torch.matmul(input_, weight)
|
|
|
|
output = reduce_scatter(output, output_dim, output_parallel_mode)
|
|
|
|
|
|
|
|
if bias is not None:
|
|
|
|
output += bias
|
|
|
|
|
|
|
|
ctx.input_parallel_mode = input_parallel_mode
|
|
|
|
ctx.weight_parallel_mode = weight_parallel_mode
|
|
|
|
ctx.output_parallel_mode = output_parallel_mode
|
|
|
|
ctx.input_dim = input_dim
|
|
|
|
ctx.weight_dim = weight_dim
|
|
|
|
ctx.output_dim = output_dim
|
|
|
|
return output
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@custom_bwd
|
2021-12-27 07:04:32 +00:00
|
|
|
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
Develop/experiments (#59)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000
* Integrate 1d tensor parallel in Colossal-AI (#39)
* fixed 1D and 2D convergence (#38)
* optimized 2D operations
* fixed 1D ViT convergence problem
* Feature/ddp (#49)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* support torch ddp
* fix loss accumulation
* add log for ddp
* change seed
* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* Feature/pipeline (#40)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* optimize communication of pipeline parallel
* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)
* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset
* update api for better usability (#58)
update api for better usability
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
2021-12-09 07:08:29 +00:00
|
|
|
input_, weight = ctx.saved_tensors
|
|
|
|
with torch.no_grad():
|
2021-12-27 07:04:32 +00:00
|
|
|
output_grad = all_gather(output_grad, ctx.output_dim, ctx.output_parallel_mode)
|
|
|
|
|
|
|
|
async_ops = list()
|
Develop/experiments (#59)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000
* Integrate 1d tensor parallel in Colossal-AI (#39)
* fixed 1D and 2D convergence (#38)
* optimized 2D operations
* fixed 1D ViT convergence problem
* Feature/ddp (#49)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* support torch ddp
* fix loss accumulation
* add log for ddp
* change seed
* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* Feature/pipeline (#40)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* optimize communication of pipeline parallel
* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)
* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset
* update api for better usability (#58)
update api for better usability
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
2021-12-09 07:08:29 +00:00
|
|
|
|
|
|
|
input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
|
2021-12-27 07:04:32 +00:00
|
|
|
input_grad, op = reduce_scatter(input_grad, ctx.input_dim, ctx.input_parallel_mode, async_op=True)
|
|
|
|
async_ops.append(op)
|
|
|
|
|
Develop/experiments (#59)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000
* Integrate 1d tensor parallel in Colossal-AI (#39)
* fixed 1D and 2D convergence (#38)
* optimized 2D operations
* fixed 1D ViT convergence problem
* Feature/ddp (#49)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* support torch ddp
* fix loss accumulation
* add log for ddp
* change seed
* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* Feature/pipeline (#40)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* optimize communication of pipeline parallel
* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)
* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset
* update api for better usability (#58)
update api for better usability
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
2021-12-09 07:08:29 +00:00
|
|
|
weight_grad = torch.matmul(
|
2021-12-27 07:04:32 +00:00
|
|
|
input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
|
2022-02-17 14:03:39 +00:00
|
|
|
weight_grad, op = reduce_scatter(weight_grad, ctx.weight_dim, ctx.weight_parallel_mode, async_op=True)
|
2021-12-27 07:04:32 +00:00
|
|
|
async_ops.append(op)
|
|
|
|
|
Develop/experiments (#59)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000
* Integrate 1d tensor parallel in Colossal-AI (#39)
* fixed 1D and 2D convergence (#38)
* optimized 2D operations
* fixed 1D ViT convergence problem
* Feature/ddp (#49)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* support torch ddp
* fix loss accumulation
* add log for ddp
* change seed
* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* Feature/pipeline (#40)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* optimize communication of pipeline parallel
* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)
* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset
* update api for better usability (#58)
update api for better usability
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
2021-12-09 07:08:29 +00:00
|
|
|
if ctx.use_bias:
|
2021-12-27 07:04:32 +00:00
|
|
|
bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
|
|
|
|
bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
|
|
|
|
async_ops.append(op)
|
2022-02-14 03:15:02 +00:00
|
|
|
else:
|
|
|
|
bias_grad = None
|
2021-12-27 07:04:32 +00:00
|
|
|
|
|
|
|
for op in async_ops:
|
|
|
|
if op is not None:
|
|
|
|
op.wait()
|
Develop/experiments (#59)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000
* Integrate 1d tensor parallel in Colossal-AI (#39)
* fixed 1D and 2D convergence (#38)
* optimized 2D operations
* fixed 1D ViT convergence problem
* Feature/ddp (#49)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* support torch ddp
* fix loss accumulation
* add log for ddp
* change seed
* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* Feature/pipeline (#40)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* optimize communication of pipeline parallel
* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)
* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset
* update api for better usability (#58)
update api for better usability
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
2021-12-09 07:08:29 +00:00
|
|
|
|
|
|
|
return input_grad, weight_grad, bias_grad, None, None, None, None, None, None
|
|
|
|
|
|
|
|
|
2022-02-14 03:15:02 +00:00
|
|
|
def linear_3d(input_: Tensor,
|
|
|
|
weight: Tensor,
|
|
|
|
bias: Optional[Tensor],
|
|
|
|
input_parallel_mode: ParallelMode,
|
|
|
|
weight_parallel_mode: ParallelMode,
|
|
|
|
output_parallel_mode: ParallelMode,
|
|
|
|
input_dim: int = 0,
|
|
|
|
weight_dim: int = -1,
|
|
|
|
output_dim: int = 0) -> Tensor:
|
2022-03-25 05:02:39 +00:00
|
|
|
r"""Linear layer for 3D parallelism.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
input_ (:class:`torch.tensor`): input matrix.
|
|
|
|
weight (:class:`torch.tensor`): matrix of weight.
|
|
|
|
bias (:class:`torch.tensor`): matrix of bias.
|
|
|
|
input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
|
|
|
|
weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
|
|
|
|
output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.
|
|
|
|
input_dim (int, optional): dimension of input, defaults to 0.
|
|
|
|
weight_dim (int, optional): dimension of weight, defaults to -1.
|
|
|
|
output_dim (int, optional): dimension of output, defaults to 0.
|
|
|
|
|
|
|
|
Note:
|
|
|
|
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
|
|
|
|
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
|
2022-01-10 10:05:58 +00:00
|
|
|
"""
|
2022-02-14 03:15:02 +00:00
|
|
|
return _Linear3D.apply(input_, weight, bias, input_parallel_mode, weight_parallel_mode, output_parallel_mode,
|
|
|
|
input_dim, weight_dim, output_dim)
|
|
|
|
|
|
|
|
|
|
|
|
class _Classifier3D(torch.autograd.Function):
|
|
|
|
|
Develop/experiments (#59)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000
* Integrate 1d tensor parallel in Colossal-AI (#39)
* fixed 1D and 2D convergence (#38)
* optimized 2D operations
* fixed 1D ViT convergence problem
* Feature/ddp (#49)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* support torch ddp
* fix loss accumulation
* add log for ddp
* change seed
* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* Feature/pipeline (#40)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* optimize communication of pipeline parallel
* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)
* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset
* update api for better usability (#58)
update api for better usability
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
2021-12-09 07:08:29 +00:00
|
|
|
@staticmethod
|
|
|
|
@custom_fwd(cast_inputs=torch.float16)
|
2021-12-27 07:04:32 +00:00
|
|
|
def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_parallel_mode: ParallelMode,
|
|
|
|
weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor:
|
|
|
|
ctx.use_bias = bias is not None
|
|
|
|
|
|
|
|
ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
|
|
|
|
src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
|
|
|
|
weight = broadcast(weight, src_rank, input_parallel_mode)
|
|
|
|
ctx.save_for_backward(input_, weight)
|
|
|
|
|
|
|
|
output = torch.matmul(input_, weight.transpose(0, 1))
|
|
|
|
output = all_reduce(output, output_parallel_mode)
|
|
|
|
|
|
|
|
if bias is not None:
|
|
|
|
output += bias
|
|
|
|
|
|
|
|
ctx.src_rank = src_rank
|
|
|
|
ctx.input_parallel_mode = input_parallel_mode
|
|
|
|
ctx.weight_parallel_mode = weight_parallel_mode
|
|
|
|
ctx.output_parallel_mode = output_parallel_mode
|
|
|
|
return output
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@custom_bwd
|
|
|
|
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
|
|
|
input_, weight = ctx.saved_tensors
|
|
|
|
with torch.no_grad():
|
|
|
|
async_ops = list()
|
|
|
|
|
|
|
|
weight_grad = torch.matmul(
|
|
|
|
output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1]))
|
|
|
|
weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode)
|
|
|
|
if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode):
|
|
|
|
weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
|
|
|
|
async_ops.append(op)
|
|
|
|
else:
|
|
|
|
weight_grad = None
|
|
|
|
|
|
|
|
if ctx.use_bias:
|
|
|
|
bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
|
|
|
|
bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode)
|
|
|
|
bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
|
|
|
|
async_ops.append(op)
|
2022-02-14 03:15:02 +00:00
|
|
|
else:
|
|
|
|
bias_grad = None
|
2021-12-27 07:04:32 +00:00
|
|
|
|
|
|
|
input_grad = torch.matmul(output_grad, weight)
|
|
|
|
|
|
|
|
for op in async_ops:
|
|
|
|
if op is not None:
|
|
|
|
op.wait()
|
|
|
|
|
|
|
|
return input_grad, weight_grad, bias_grad, None, None, None, None, None, None
|
|
|
|
|
|
|
|
|
2022-02-14 03:15:02 +00:00
|
|
|
def classifier_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_parallel_mode: ParallelMode,
|
|
|
|
weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor:
|
2022-03-25 05:02:39 +00:00
|
|
|
r"""3D parallel classifier.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
input_ (:class:`torch.tensor`): input matrix.
|
|
|
|
weight (:class:`torch.tensor`): matrix of weight.
|
|
|
|
bias (:class:`torch.tensor`): matrix of bias.
|
|
|
|
input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
|
|
|
|
weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
|
|
|
|
output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.
|
|
|
|
|
|
|
|
Note:
|
|
|
|
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
|
|
|
|
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
|
2022-01-10 10:05:58 +00:00
|
|
|
"""
|
2022-02-14 03:15:02 +00:00
|
|
|
return _Classifier3D.apply(input_, weight, bias, input_parallel_mode, weight_parallel_mode, output_parallel_mode)
|
|
|
|
|
|
|
|
|
|
|
|
class _Layernorm3D(torch.autograd.Function):
|
|
|
|
|
2021-12-27 07:04:32 +00:00
|
|
|
@staticmethod
|
|
|
|
@custom_fwd(cast_inputs=torch.float32)
|
2022-04-14 03:43:56 +00:00
|
|
|
def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float,
|
2021-12-27 07:04:32 +00:00
|
|
|
input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
|
Develop/experiments (#59)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000
* Integrate 1d tensor parallel in Colossal-AI (#39)
* fixed 1D and 2D convergence (#38)
* optimized 2D operations
* fixed 1D ViT convergence problem
* Feature/ddp (#49)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* support torch ddp
* fix loss accumulation
* add log for ddp
* change seed
* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* Feature/pipeline (#40)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* optimize communication of pipeline parallel
* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)
* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset
* update api for better usability (#58)
update api for better usability
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
2021-12-09 07:08:29 +00:00
|
|
|
output_parallel_mode: ParallelMode) -> Tensor:
|
2021-12-27 07:04:32 +00:00
|
|
|
mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
|
Develop/experiments (#59)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000
* Integrate 1d tensor parallel in Colossal-AI (#39)
* fixed 1D and 2D convergence (#38)
* optimized 2D operations
* fixed 1D ViT convergence problem
* Feature/ddp (#49)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* support torch ddp
* fix loss accumulation
* add log for ddp
* change seed
* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* Feature/pipeline (#40)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* optimize communication of pipeline parallel
* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)
* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset
* update api for better usability (#58)
update api for better usability
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
2021-12-09 07:08:29 +00:00
|
|
|
mu = input_ - mean
|
2021-12-27 07:04:32 +00:00
|
|
|
var = all_reduce(torch.sum(mu**2, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
|
Develop/experiments (#59)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000
* Integrate 1d tensor parallel in Colossal-AI (#39)
* fixed 1D and 2D convergence (#38)
* optimized 2D operations
* fixed 1D ViT convergence problem
* Feature/ddp (#49)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* support torch ddp
* fix loss accumulation
* add log for ddp
* change seed
* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* Feature/pipeline (#40)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* optimize communication of pipeline parallel
* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)
* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset
* update api for better usability (#58)
update api for better usability
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
2021-12-09 07:08:29 +00:00
|
|
|
sigma = torch.sqrt(var + eps)
|
|
|
|
|
|
|
|
ctx.save_for_backward(mu, sigma, weight)
|
|
|
|
|
|
|
|
z = mu / sigma
|
2022-04-14 03:43:56 +00:00
|
|
|
output = weight * z
|
|
|
|
if bias is not None:
|
|
|
|
output = output + bias
|
Develop/experiments (#59)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000
* Integrate 1d tensor parallel in Colossal-AI (#39)
* fixed 1D and 2D convergence (#38)
* optimized 2D operations
* fixed 1D ViT convergence problem
* Feature/ddp (#49)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* support torch ddp
* fix loss accumulation
* add log for ddp
* change seed
* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* Feature/pipeline (#40)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* optimize communication of pipeline parallel
* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)
* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset
* update api for better usability (#58)
update api for better usability
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
2021-12-09 07:08:29 +00:00
|
|
|
|
2022-04-14 03:43:56 +00:00
|
|
|
ctx.use_bias = bias is not None
|
Develop/experiments (#59)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000
* Integrate 1d tensor parallel in Colossal-AI (#39)
* fixed 1D and 2D convergence (#38)
* optimized 2D operations
* fixed 1D ViT convergence problem
* Feature/ddp (#49)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* support torch ddp
* fix loss accumulation
* add log for ddp
* change seed
* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* Feature/pipeline (#40)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* optimize communication of pipeline parallel
* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)
* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset
* update api for better usability (#58)
update api for better usability
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
2021-12-09 07:08:29 +00:00
|
|
|
ctx.normalized_shape = normalized_shape
|
|
|
|
ctx.input_parallel_mode = input_parallel_mode
|
|
|
|
ctx.weight_parallel_mode = weight_parallel_mode
|
|
|
|
ctx.output_parallel_mode = output_parallel_mode
|
|
|
|
|
|
|
|
return output
|
|
|
|
|
|
|
|
@staticmethod
@custom_bwd
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
    """Backward pass of the 3D-parallel layer norm.

    Args:
        ctx: autograd context filled in by ``forward``. ``ctx.saved_tensors`` holds ``(mu, sigma, weight)``;
            ``ctx`` also carries ``use_bias``, ``normalized_shape`` and the input/weight/output parallel modes.
            ``mu`` / ``sigma`` are presumably the mean-centred input and its standard deviation as computed
            in ``forward`` (not fully visible here) -- confirm.
        output_grad (:class:`torch.Tensor`): gradient w.r.t. the layer output.

    Returns:
        Gradients for ``(input_, weight, bias, normalized_shape, eps, input_parallel_mode,
        weight_parallel_mode, output_parallel_mode)``. The last five are ``None``; ``bias_grad`` is ``None``
        when the layer has no bias.
    """
    mu, sigma, weight = ctx.saved_tensors
    with torch.no_grad():
        # Element-wise contributions to the affine parameter gradients.
        weight_grad = output_grad * mu / sigma
        if ctx.use_bias:
            bias_grad = output_grad
            # Stack bias and weight grads so one pair of all-reduces serves both.
            weight_grad = torch.stack([bias_grad, weight_grad]).contiguous()
            # Keep the stack axis (0) and the trailing feature axis; sum everything in between.
            reduce_dims = tuple(range(1, weight_grad.dim() - 1))
        else:
            bias_grad = None
            # No stack axis here: keep only the trailing feature axis. (Previously axis 0 was
            # wrongly kept in this branch, leaving a per-sample gradient of the wrong shape.)
            reduce_dims = tuple(range(weight_grad.dim() - 1))
        if reduce_dims:
            # Guard: ``torch.sum(x, dim=())`` is not a no-op on every torch version.
            weight_grad = torch.sum(weight_grad, dim=reduce_dims)
        # Complete the parameter-gradient sums across the weight and input parallel groups.
        weight_grad = all_reduce(weight_grad, ctx.weight_parallel_mode)
        weight_grad = all_reduce(weight_grad, ctx.input_parallel_mode)
        if ctx.use_bias:
            bias_grad, weight_grad = weight_grad[0], weight_grad[1]

        dz = output_grad * weight
        dvar = dz * mu * (-0.5) * sigma**(-3)
        # The last dim appears to be partitioned across the output group (hence the all-reduce after
        # summing over it) -- confirm against forward.
        dvar = all_reduce(torch.sum(dvar, dim=-1, keepdim=True), ctx.output_parallel_mode)
        dmean = dz * (-1 / sigma) + dvar * -2 * mu / ctx.normalized_shape
        dmean = all_reduce(torch.sum(dmean, dim=-1, keepdim=True), ctx.output_parallel_mode)

        input_grad = dz / sigma + dvar * 2 * mu / ctx.normalized_shape + dmean / ctx.normalized_shape

    return input_grad, weight_grad, bias_grad, None, None, None, None, None
|
2021-10-28 16:21:23 +00:00
|
|
|
|
|
|
|
|
2022-04-14 03:43:56 +00:00
|
|
|
def layernorm_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float,
                 input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
                 output_parallel_mode: ParallelMode) -> Tensor:
    r"""3D parallel Layernorm.

    Thin functional wrapper around the ``_Layernorm3D`` autograd function.

    Args:
        input_ (:class:`torch.tensor`): input matrix.
        weight (:class:`torch.tensor`): matrix of weight.
        bias (:class:`torch.tensor`, optional): matrix of bias, or ``None`` for a bias-free layer.
        normalized_shape (int): size of the normalized (last) dimension.
        eps (float): a value added to the denominator for numerical stability.
        input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
        weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
        output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.

    Returns:
        :class:`torch.tensor`: the normalized output.

    Note:
        The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` could be found
        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
    """
    modes = (input_parallel_mode, weight_parallel_mode, output_parallel_mode)
    return _Layernorm3D.apply(input_, weight, bias, normalized_shape, eps, *modes)
|
|
|
|
|
|
|
|
|
|
|
|
def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
    r"""Splits 3D parallel tensor in specified dimension.

    Args:
        tensor (:class:`torch.tensor`): Input tensor.
        dim (int): Specified dimension in which to split.
        parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): Parallel mode.

    Returns:
        :class:`torch.tensor`: The contiguous chunk of ``tensor`` along ``dim`` owned by the local rank of
        ``parallel_mode``. A dimension of size 0 or 1 is not split and ``tensor`` is returned as-is.

    Raises:
        AssertionError: If ``tensor.size(dim) > 1`` and it is not a multiple of the world size of ``parallel_mode``.

    Note:
        The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` could be found
        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
    """
    dim_size = tensor.size(dim)
    # Must run before the divisibility check: otherwise a size-1 dim on a group larger than one
    # trips the assert and this shortcut is unreachable.
    if dim_size <= 1:
        return tensor
    world_size = gpc.get_world_size(parallel_mode)
    assert dim_size % world_size == 0, \
        f'The dimension {dim} to split, size ({dim_size}) is not a multiple of world size ({world_size}), ' \
        f'cannot split tensor evenly'
    output = torch.chunk(tensor, world_size, dim=dim)[gpc.get_local_rank(parallel_mode)].contiguous()
    return output
|
|
|
|
|
2022-01-21 02:44:30 +00:00
|
|
|
|
2022-02-14 03:15:02 +00:00
|
|
|
def split_batch_3d(input_: Tensor,
                   dim: int = 0,
                   input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT,
                   weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor:
    r"""Splits 3D tensor in batch.

    The tensor is chunked along ``dim`` first across the weight parallel group and then, within that chunk,
    across the input parallel group; the local rank's contiguous piece is returned.

    Args:
        input_ (:class:`torch.tensor`): Input tensor.
        dim (int): Specified dimension in which to split.
        input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`, optional): input parallel mode.
            Currently overridden by the mode recorded in the environment (see the note in the body).
        weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`, optional): weight parallel mode.
            Currently overridden by the mode recorded in the environment (see the note in the body).

    Returns:
        :class:`torch.tensor`: The tensor has been split. A dimension of size 0 or 1 is returned unsplit.

    Raises:
        AssertionError: If ``input_.size(dim) > 1`` and it is not a multiple of
            (input world size * weight world size).

    Note:
        The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` could be found
        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
    """
    dim_size = input_.size(dim)
    # Must run before the divisibility check, otherwise it is unreachable for size-1 dims.
    if dim_size <= 1:
        return input_
    # NOTE(review): the caller-supplied modes are replaced by the env-recorded ones, so the two mode
    # arguments currently have no effect. Preserved as-is; confirm whether this is intended.
    weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
    input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
    weight_world_size = gpc.get_world_size(weight_parallel_mode)
    input_world_size = gpc.get_world_size(input_parallel_mode)

    assert dim_size % (input_world_size*weight_world_size) == 0, \
        f'The batch size ({dim_size}) is not a multiple of square of 3D depth ({input_world_size*weight_world_size}).'

    output = torch.chunk(input_, weight_world_size,
                         dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
    output = torch.chunk(output, input_world_size,
                         dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous()
    return output
|
2021-10-28 16:21:23 +00:00
|
|
|
|
|
|
|
|
2022-02-14 03:15:02 +00:00
|
|
|
class _ReduceTensor3D(torch.autograd.Function):
    """All-reduce in the forward pass, identity in the backward pass.

    ``forward`` all-reduces ``input_`` over ``parallel_mode``; ``backward`` hands the incoming gradient
    back untouched and yields no gradient for ``parallel_mode``.
    """

    @staticmethod
    def forward(ctx, input_, parallel_mode):
        reduced = all_reduce(input_, parallel_mode)
        return reduced

    @staticmethod
    def backward(ctx, output_grad):
        # One slot per forward input: the tensor, then the non-differentiable parallel mode.
        grads = (output_grad, None)
        return grads
|
|
|
|
|
|
|
|
|
|
|
|
def reduce_tensor_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor:
    r"""All-reduce ``tensor`` across the group of ``parallel_mode``.

    In the backward pass the gradient is propagated unchanged.

    Args:
        tensor (:class:`torch.tensor`): Input tensor.
        parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): Parallel mode.

    Returns:
        :class:`torch.tensor`: the all-reduced tensor.

    Note:
        The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` could be found
        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
    """
    reduced = _ReduceTensor3D.apply(tensor, parallel_mode)
    return reduced
|
|
|
|
|
|
|
|
|
2022-02-17 14:03:39 +00:00
|
|
|
class _AllGatherTensor3D(torch.autograd.Function):
    """All-gather along ``dim`` in the forward pass; reduce-scatter the gradient along ``dim`` in backward."""

    @staticmethod
    def forward(ctx, input_, dim, parallel_mode):
        # Remember where/how we gathered so backward can scatter the gradient the same way.
        ctx.dim, ctx.parallel_mode = dim, parallel_mode
        return all_gather(input_, dim, parallel_mode)

    @staticmethod
    def backward(ctx, output_grad):
        scattered = reduce_scatter(output_grad, ctx.dim, ctx.parallel_mode)
        # ``dim`` and ``parallel_mode`` receive no gradient.
        return scattered, None, None
|
2022-02-14 03:15:02 +00:00
|
|
|
|
|
|
|
|
2022-02-17 14:03:39 +00:00
|
|
|
def all_gather_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
    r"""All-gather ``tensor`` along ``dim`` across the group of ``parallel_mode``.

    In the backward pass the gradient is reduce-scattered along the same dimension.

    Args:
        tensor (:class:`torch.tensor`): Input tensor.
        dim (int): Dimension to gather.
        parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): Parallel mode.

    Returns:
        :class:`torch.tensor`: the gathered tensor.

    Note:
        The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` could be found
        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
    """
    return _AllGatherTensor3D.apply(tensor, dim, parallel_mode)
|
2022-02-14 03:15:02 +00:00
|
|
|
|
|
|
|
|
|
|
|
class _ReduceScatterTensor3D(torch.autograd.Function):
    """Reduce-scatter along ``dim`` in the forward pass; all-gather the gradient along ``dim`` in backward."""

    @staticmethod
    def forward(ctx, input_, dim, parallel_mode):
        # Stash the scatter axis and group for the matching all-gather in backward.
        ctx.dim, ctx.parallel_mode = dim, parallel_mode
        scattered = reduce_scatter(input_, dim, parallel_mode)
        return scattered

    @staticmethod
    def backward(ctx, output_grad):
        gathered = all_gather(output_grad, ctx.dim, ctx.parallel_mode)
        # ``dim`` and ``parallel_mode`` receive no gradient.
        return gathered, None, None
|
|
|
|
|
|
|
|
|
|
|
|
def reduce_scatter_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
    r"""Reduce-scatter ``tensor`` along ``dim`` across the group of ``parallel_mode``.

    In the backward pass the gradient is all-gathered along the same dimension.

    Args:
        tensor (:class:`torch.tensor`): Input tensor.
        dim (int): Dimension to scatter.
        parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): Parallel mode.

    Returns:
        :class:`torch.tensor`: the local shard of the reduced tensor.

    Raises:
        AssertionError: If ``tensor.size(dim)`` is not a multiple of the world size of ``parallel_mode``.

    Note:
        The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` could be found
        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
    """
    dim_size = tensor.size(dim)
    world_size = gpc.get_world_size(parallel_mode)
    # The message previously talked about "batch size" and "square of 3D depth", but this checks an
    # arbitrary dim against the group's world size.
    assert dim_size % world_size == 0, \
        f'The dimension {dim} to scatter, size ({dim_size}) is not a multiple of world size ({world_size}), ' \
        f'cannot scatter tensor evenly'

    return _ReduceScatterTensor3D.apply(tensor, dim, parallel_mode)
|
|
|
|
|
|
|
|
|
|
|
|
class _ReduceByBatch3D(torch.autograd.Function):
    """Sum (or average) a tensor over the input and weight parallel groups.

    ``forward`` all-reduces over ``input_parallel_mode`` and then ``weight_parallel_mode``. When
    ``reduce_mean`` is set, it divides by the product of the two world sizes. ``forward`` is decorated with
    ``custom_fwd(cast_inputs=torch.float32)``. ``backward`` returns the gradient unchanged, or divided by the
    same product when ``reduce_mean`` was set.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx,
                input_: Tensor,
                input_parallel_mode: ParallelMode,
                weight_parallel_mode: ParallelMode,
                reduce_mean: bool = False) -> Tensor:
        summed = all_reduce(all_reduce(input_, input_parallel_mode), weight_parallel_mode)
        ctx.reduce_mean = reduce_mean
        if not reduce_mean:
            return summed.clone()
        ctx.reduce_size = gpc.get_world_size(input_parallel_mode) * gpc.get_world_size(weight_parallel_mode)
        return summed.clone() / ctx.reduce_size

    @staticmethod
    @custom_bwd
    def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
        grad = output_grad / ctx.reduce_size if ctx.reduce_mean else output_grad
        # The two parallel modes and the ``reduce_mean`` flag receive no gradient.
        return grad, None, None, None
|
2021-10-28 16:21:23 +00:00
|
|
|
|
|
|
|
|
2022-02-14 03:15:02 +00:00
|
|
|
def reduce_by_batch_3d(tensor: Tensor,
                       input_parallel_mode: ParallelMode,
                       weight_parallel_mode: ParallelMode,
                       reduce_mean: bool = False) -> Tensor:
    r"""All-reduce ``tensor`` over the input and the weight parallel groups.

    Args:
        tensor (:class:`torch.tensor`): Input tensor.
        input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
        weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
        reduce_mean (bool, optional): If ``True``, divide the result by
            (input parallel size * weight parallel size); defaults to ``False``.

    Returns:
        :class:`torch.tensor`: the reduced (or averaged) tensor.

    Note:
        The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` could be found
        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
    """
    reduced = _ReduceByBatch3D.apply(tensor, input_parallel_mode, weight_parallel_mode, reduce_mean)
    return reduced
|
|
|
|
|
|
|
|
|
|
|
|
class _BroadcastWeight3D_FromDiagonal(torch.autograd.Function):
|
2022-03-25 05:02:39 +00:00
|
|
|
r"""broadcast weight from diagonal.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
input_ (:class:`torch.tensor`): input matrix.
|
|
|
|
input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
|
|
|
|
weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
|
|
|
|
output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.
|
|
|
|
|
|
|
|
Note:
|
|
|
|
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
|
|
|
|
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
|
2022-01-10 10:05:58 +00:00
|
|
|
"""
|
2022-02-14 03:15:02 +00:00
|
|
|
|
2021-10-28 16:21:23 +00:00
|
|
|
@staticmethod
|
Develop/experiments (#59)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000
* Integrate 1d tensor parallel in Colossal-AI (#39)
* fixed 1D and 2D convergence (#38)
* optimized 2D operations
* fixed 1D ViT convergence problem
* Feature/ddp (#49)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* support torch ddp
* fix loss accumulation
* add log for ddp
* change seed
* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* Feature/pipeline (#40)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* optimize communication of pipeline parallel
* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)
* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset
* update api for better usability (#58)
update api for better usability
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
2021-12-09 07:08:29 +00:00
|
|
|
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input_: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
            output_parallel_mode: ParallelMode) -> Tensor:
    """Broadcast ``input_`` inside the ``input_parallel_mode`` group.

    The broadcast source is the member of ``gpc.get_ranks_in_group(input_parallel_mode)``
    at the index given by this rank's local rank in ``output_parallel_mode``.

    Args:
        ctx: autograd context; receives ``src_rank`` and the three parallel modes for use
            by ``backward``.
        input_: tensor to broadcast.
        input_parallel_mode: process group the broadcast runs over.
        weight_parallel_mode: not used in the forward pass; saved on ``ctx`` for backward.
        output_parallel_mode: its local rank selects the broadcast source.

    Returns:
        The tensor returned by ``broadcast``.
    """
    group_ranks = gpc.get_ranks_in_group(input_parallel_mode)
    source = group_ranks[gpc.get_local_rank(output_parallel_mode)]

    # Stash everything backward needs. These are plain attribute writes with no
    # communication, so doing them before the collective does not change behavior.
    ctx.src_rank = source
    ctx.input_parallel_mode, ctx.weight_parallel_mode, ctx.output_parallel_mode = (
        input_parallel_mode,
        weight_parallel_mode,
        output_parallel_mode,
    )

    return broadcast(input_, source, input_parallel_mode)
|
2021-10-28 16:21:23 +00:00
|
|
|
|
|
|
|
@staticmethod
@custom_bwd
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
    """Send the gradient of the broadcast back toward the rank that forward broadcast from.

    Every rank of ``ctx.input_parallel_mode`` calls ``reduce`` toward ``ctx.src_rank``.
    A rank whose local rank in ``ctx.input_parallel_mode`` equals its local rank in
    ``ctx.output_parallel_mode`` keeps the result and additionally ``all_reduce``s it over
    ``ctx.weight_parallel_mode``. Every other rank returns ``None`` as the input gradient.

    Args:
        ctx: autograd context populated by ``forward`` (``src_rank`` and the three
            parallel modes).
        output_grad: gradient w.r.t. the output of ``forward``.

    Returns:
        A 4-tuple. The first element is the gradient for ``input_`` (or ``None``). The
        remaining three are ``None``, one for each ``ParallelMode`` argument of
        ``forward``, since those are not differentiable.
    """
    # Collective over the whole input group toward the forward broadcast source;
    # presumably a sum-reduce (see colossalai.communication.reduce). On ranks that take
    # the else-branch below, the reduced value is discarded.
    input_grad = reduce(output_grad, ctx.src_rank, ctx.input_parallel_mode)
    # forward chose src_rank = ranks_in_group(input)[local_rank(output)], so equal local
    # ranks identify this rank as its group's source (assuming get_ranks_in_group lists
    # ranks in local-rank order — TODO confirm).
    if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode):
        # NOTE(review): this collective runs only on ranks that take this branch; it relies
        # on every member of the weight_parallel_mode group taking the same branch,
        # otherwise the all_reduce would block — confirm with the 3D group layout.
        input_grad = all_reduce(input_grad, ctx.weight_parallel_mode)
    else:
        input_grad = None
    return input_grad, None, None, None
|
2022-02-14 03:15:02 +00:00
|
|
|
|
|
|
|
|
|
|
|
def broadcast_weight_3d_from_diagonal(tensor: Tensor, input_parallel_mode: ParallelMode,
                                      weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor:
    """Autograd-aware broadcast of ``tensor`` implemented by ``_BroadcastWeight3D_FromDiagonal``.

    Forward: ``tensor`` is broadcast within the ``input_parallel_mode`` group from the
    member at index ``gpc.get_local_rank(output_parallel_mode)`` of that group.

    Backward: the gradient is reduced to that source rank. Ranks whose local ranks in the
    input and output modes match then all-reduce it over ``weight_parallel_mode``. All
    other ranks receive no gradient.

    Args:
        tensor: tensor to broadcast.
        input_parallel_mode: group the broadcast and gradient reduce run over.
        weight_parallel_mode: group the gradient is all-reduced over in backward.
        output_parallel_mode: its local rank selects the broadcast source.

    Returns:
        The broadcast tensor.
    """
    autograd_fn = _BroadcastWeight3D_FromDiagonal
    # Positional order must match autograd_fn.forward's signature after ``ctx``.
    modes = (input_parallel_mode, weight_parallel_mode, output_parallel_mode)
    return autograd_fn.apply(tensor, *modes)
|