[doc] add deepspeed citation and copyright (#2996)

* [doc] add deepspeed citation and copyright

* [doc] add deepspeed citation and copyright

* [doc] add deepspeed citation and copyright
pull/2999/head
ver217 2023-03-04 20:08:11 +08:00 committed by GitHub
parent e0a1c1321c
commit 823f3b9cf4
19 changed files with 336 additions and 289 deletions

View File

@@ -1,16 +1,16 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
+# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
-from typing import List, Iterable
+from typing import Iterable, List, Optional, Type
+from torch import Tensor
 from torch.nn import Module
 from torch.nn.modules.loss import _Loss
-from colossalai.logging import get_dist_logger
-from torch import Tensor
-from colossalai.gemini.ophooks import register_ophooks_recursively, BaseOpHook
-from colossalai.engine.schedule import BaseSchedule, NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule
-from typing import Optional, Type
 from colossalai.engine.gradient_handler import BaseGradientHandler
+from colossalai.engine.schedule import BaseSchedule, InterleavedPipelineSchedule, NonPipelineSchedule, PipelineSchedule
+from colossalai.gemini.ophooks import BaseOpHook, register_ophooks_recursively
 from colossalai.logging import get_dist_logger
@@ -93,7 +93,7 @@ class Engine:
         if self.uses_pipeline:
             self._schedule.pre_processing(self)
-        #register hook if any
+        # register hook if any
         if len(self._ophook_list) > 0:
             register_ophooks_recursively(self._model, self._ophook_list)

View File

@@ -1,7 +1,12 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
 import operator
 from functools import reduce
 from typing import Any, Optional, Tuple, Union
 import torch
 from ..registry import meta_profiler_function

View File

@@ -1,8 +1,13 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+import math
 import operator
 from functools import reduce
-import math
 from typing import Tuple
 import torch
 from ..registry import meta_profiler_module

View File

@@ -1,5 +1,10 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
 from typing import Tuple, Union
 import torch
 from ..registry import meta_profiler_module

View File

@@ -1,7 +1,7 @@
+# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
-import torch
-from typing import List, Callable, Optional
 from abc import ABC, abstractmethod
+from typing import Callable, List, Optional
 import torch

View File

@@ -1,6 +1,7 @@
 /* Copyright 2021 The LightSeq Team
 Copyright Microsoft DeepSpeed
 This file is adapted from Microsoft DeepSpeed
+Licensed under the MIT License.
 */
 #include "cublas_wrappers.h"

View File

@@ -1,6 +1,7 @@
 /* Copyright 2021 The LightSeq Team
 Copyright Microsoft DeepSpeed
 This file is adapted from Microsoft DeepSpeed
+Licensed under the MIT License.
 */
 #pragma once

View File

@@ -1,68 +1,69 @@
 #pragma once
 /* Copyright 2021 The LightSeq Team
 Copyright Microsoft DeepSpeed
 This file is adapted from Microsoft DeepSpeed
+Licensed under the MIT License.
 */
 #include <cuda.h>
 #include <cuda_fp16.h>
 #include <stdio.h>

 #include <array>

 #include "cublas_wrappers.h"
 #include "kernels.h"

 template <typename T>
 class FeedForward {
  public:
   struct Config {
     int outputSize;
     int inputSize;
     std::array<int, 3> gemm_algos;
     Config(int outputs, int inputs)
         : outputSize(outputs),
           inputSize(inputs),
           gemm_algos(std::array<int, 3>({99, 99, 99})) {}
   };

   FeedForward(Config config) : config_(config) {}

   ~FeedForward() {}

   void Forward(int bsz, const T *input_ptr, const T *weights, T *out,
                cublasHandle_t &_cublasHandle) {
     float alpha = T(1.);
     float beta = T(0.);

     cublas_gemm_ex(_cublasHandle, CUBLAS_OP_T, CUBLAS_OP_N, config_.outputSize,
                    bsz, config_.inputSize, &alpha, &beta, weights, input_ptr,
                    out, cublasGemmAlgo_t(config_.gemm_algos[0]));
   }
   void Backward(int bsz, const T *out_grad, const T *input_ptr,
                 const T *weights, T *weights_grad, T *bias_grad,
                 cublasHandle_t &_cublasHandle, cudaStream_t &stream,
                 T *inp_grad_out = nullptr, T *out_grad_trans_out = nullptr,
                 bool compute_bias = true) {
     float alpha = (T)1.0, beta = (T)0.0;
     cublas_gemm_ex(_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_T, config_.inputSize,
                    config_.outputSize, bsz, &alpha, &beta, input_ptr, out_grad,
                    weights_grad, cublasGemmAlgo_t(config_.gemm_algos[1]));

     cublas_gemm_ex(_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N, config_.inputSize,
                    bsz, config_.outputSize, &alpha, &beta, weights, out_grad,
                    inp_grad_out, cublasGemmAlgo_t(config_.gemm_algos[2]));
     if (compute_bias) {
       launch_fuse_transpose_bias_kernel<T>(out_grad, bias_grad, bsz,
                                            config_.outputSize, stream);
     }
   }

   void reset_size(int outputSize, int inputSize) {
     config_.outputSize = outputSize;
     config_.inputSize = inputSize;
   }

  private:
   Config config_;
 };
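For orientation, the two cuBLAS calls above appear to implement a bias-free linear layer and its gradients: Forward computes out = input @ weights^T, and Backward produces the weight, input, and (via the fused kernel) bias gradients. A minimal PyTorch sketch of the same math with illustrative shapes; this is an editorial aid, not code from this commit:

import torch

# Shapes implied by the cublas_gemm_ex calls above (illustrative values).
bsz, input_size, output_size = 8, 64, 32
x = torch.randn(bsz, input_size, requires_grad=True)           # input_ptr
w = torch.randn(output_size, input_size, requires_grad=True)   # weights

# Forward: out = x @ w^T (CUBLAS_OP_T is applied to the weights).
out = x @ w.t()

# Backward: the first GEMM gives dW = dOut^T @ x, the second gives dX = dOut @ w,
# and launch_fuse_transpose_bias_kernel reduces dOut over the batch for the bias grad.
d_out = torch.randn_like(out)
d_w = d_out.t() @ x          # weights_grad
d_x = d_out @ w              # inp_grad_out
d_bias = d_out.sum(dim=0)    # bias_grad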

View File

@@ -1,99 +1,100 @@
 /* Copyright 2021 The LightSeq Team
 Copyright Microsoft DeepSpeed
 This file is adapted from Microsoft DeepSpeed
+Licensed under the MIT License.
 */
 #pragma once

 #include <cuda.h>
 #include <cuda_fp16.h>
 #include <stdio.h>

 #include <array>

 #include "cublas_wrappers.h"

 template <typename T>
 class StridedBatchGemm {
  public:
   struct Config {
     int m;
     int n;
     int k;
     float alpha;
     float beta;
     cublasOperation_t op_A;
     cublasOperation_t op_B;
     std::array<int, 3> gemm_algos;

     Config(float param_alpha, float param_beta, cublasOperation_t opA,
            cublasOperation_t opB)
         : alpha(param_alpha),
           beta(param_beta),
           op_A(opA),
           op_B(opB),
           gemm_algos(std::array<int, 3>({99, 99, 99})) {}
     void SetConfig(int mm, int nn, int kk) {
       m = mm;
       n = nn;
       k = kk;
     }
   };

   StridedBatchGemm(const Config &config) : _config(config) {}

   virtual ~StridedBatchGemm() {}

   void Forward(int bsz, T *output, const T *_buffer_a, const T *_buffer_b,
                cublasHandle_t handle) {
     int stride_a = _config.m * _config.k;
     int stride_b = _config.n * _config.k;
     int stride_c = _config.m * _config.n;

     cublas_strided_batched_gemm(
         handle, _config.m, _config.n, _config.k, &_config.alpha, &_config.beta,
         _buffer_a, _buffer_b, output, _config.op_A, _config.op_B, stride_a,
         stride_b, stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[0]));
   }

   void Backward(int bsz, const T *d_output, const T *_buffer_a,
                 const T *_buffer_b, cublasHandle_t handle,
                 T *inpGradA = nullptr, T *inpGradB = nullptr) {
     int mb = (_config.op_A == CUBLAS_OP_T ? _config.k : _config.m);
     int kb = (_config.op_A == CUBLAS_OP_T ? _config.m : _config.k);

     int stride_a = mb * _config.n;
     int stride_b = _config.n * kb;
     int stride_c = _config.m * _config.k;

     // B need to transpose.
     cublasOperation_t op_b =
         (_config.op_B == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);

     // Calculate d_A.
     cublas_strided_batched_gemm(
         handle, mb, kb, _config.n, &_config.alpha, &_config.beta,
         (_config.op_A == CUBLAS_OP_T ? _buffer_b : d_output),
         (_config.op_A == CUBLAS_OP_T ? d_output : _buffer_b), inpGradA,
         CUBLAS_OP_N, op_b, stride_a, stride_b, stride_c, bsz,
         cublasGemmAlgo_t(_config.gemm_algos[1]));

     // A need to transpose.
     cublasOperation_t op_a =
         (_config.op_A == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);

     stride_a = _config.m * _config.k;
     stride_b = _config.m * _config.n;
     stride_c = _config.n * _config.k;

     // Calculate d_B.
     cublas_strided_batched_gemm(
         handle, _config.k, _config.n, _config.m, &_config.alpha, &_config.beta,
         _buffer_a, d_output, inpGradB, op_a, CUBLAS_OP_N, stride_a, stride_b,
         stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[2]));
   }

   inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); }

  private:
   Config _config;
 };
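For orientation, cublas_strided_batched_gemm performs one GEMM per batch element, with each operand advancing by a fixed element stride (m*k, n*k, and m*n above). A rough PyTorch analogue of the Forward path that ignores the op_A/op_B transpose flags; illustrative only, not code from this commit:

import torch

# Illustrative sizes; alpha and beta mirror the Config fields above.
bsz, m, n, k = 4, 16, 16, 8
alpha, beta = 1.0, 0.0

a = torch.randn(bsz, m, k)   # _buffer_a: one m-by-k matrix per batch element
b = torch.randn(bsz, k, n)   # _buffer_b: one k-by-n matrix per batch element
c = torch.zeros(bsz, m, n)   # output

# Each batch element is a separate GEMM: C[i] = alpha * A[i] @ B[i] + beta * C[i].
c = alpha * torch.bmm(a, b) + beta * c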

View File

@ -1,5 +1,10 @@
// modified from // modified from
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu // https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu
/* Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
Licensed under the MIT License.
*/
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/AccumulateType.h> #include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>

View File

@@ -1,12 +1,18 @@
-// modified from https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_apply.cuh
+// modified from
+// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_apply.cuh
+/* Copyright 2020 The Microsoft DeepSpeed Team
+Copyright NVIDIA/apex
+This file is adapted from fused adam in NVIDIA/apex, commit a109f85
+Licensed under the MIT License.
+*/
 #include <ATen/ATen.h>
 #include <ATen/AccumulateType.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/Exceptions.h>
-#include <c10/cuda/CUDAGuard.h>
-#include "compat.h"
 #include <assert.h>
+#include <c10/cuda/CUDAGuard.h>
+#include "compat.h"

 // #include <iostream>
@@ -17,117 +23,108 @@ constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
 constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};

 template <int n>
 struct TensorListMetadata {
   void *addresses[n][depth_to_max_tensors[n - 1]];
   int sizes[depth_to_max_tensors[n - 1]];
   unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
   int block_to_chunk[depth_to_max_blocks[n - 1]];  // I fear this needs to be a
                                                    // full int.
   int start_tensor_this_launch;
 };

 template <typename T, typename U, typename... ArgTypes>
 __global__ void multi_tensor_apply_kernel(int chunk_size,
                                           volatile int *noop_flag, T tl,
                                           U callable, ArgTypes... args) {
   // Hand the chunk information to the user-supplied functor to process however
   // it likes.
   callable(chunk_size, noop_flag, tl, args...);
 }

 template <int depth, typename T, typename... ArgTypes>
 void multi_tensor_apply(
     int block_size, int chunk_size, const at::Tensor &noop_flag,
     const std::vector<std::vector<at::Tensor>> &tensor_lists, T callable,
     ArgTypes... args) {
   TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
   int len0 = tensor_lists[0].size();
   TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
   auto ref_device = tensor_lists[0][0].device();
   TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
   for (int l = 0; l < tensor_lists.size();
        l++)  // No range-based for because I need indices
   {
     TORCH_CHECK(tensor_lists[l].size() == len0,
                 "Size mismatch among tensor lists");
     for (int t = 0; t < tensor_lists[l].size(); t++) {
       // TODO: Print which tensor fails.
       bool contiguous_memory = tensor_lists[l][t].is_contiguous();
 #ifdef VERSION_GE_1_5
       contiguous_memory =
           (contiguous_memory ||
            tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
 #endif
       TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
       TORCH_CHECK(tensor_lists[l][t].device() == ref_device,
                   "A tensor was not on the same device as the first tensor");
       TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(),
                   "Size mismatch");
     }
   }

   int ntensors = tensor_lists[0].size();

   TensorListMetadata<depth> tl;

   const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
   auto stream = at::cuda::getCurrentCUDAStream();

   tl.start_tensor_this_launch = 0;
   int loc_block_info = 0;
   int loc_tensor_info = 0;
   for (int t = 0; t < ntensors; t++) {
     tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
     for (int d = 0; d < depth; d++)
       tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
     loc_tensor_info++;

     int chunks_this_tensor =
         (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;

     for (int chunk = 0; chunk < chunks_this_tensor; chunk++) {
       // std::cout << chunks_this_tensor << std::endl;
       tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
       tl.block_to_chunk[loc_block_info] = chunk;
       loc_block_info++;

       bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
                            chunk == chunks_this_tensor - 1);
       bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
       bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
       if (tensors_full || blocks_full || last_chunk) {
         // using accscalar_t = acc_type<scalar_t, true>;
         multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
             chunk_size, noop_flag.DATA_PTR<int>(), tl, callable, args...);

         AT_CUDA_CHECK(cudaGetLastError());

         // Reset. The control flow possibilities here make my brain hurt.
         loc_block_info = 0;
         if (chunk == chunks_this_tensor - 1) {
           // std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3
           // << std::endl;
           loc_tensor_info = 0;
           tl.start_tensor_this_launch = t + 1;
         } else {
           // std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3
           // << std::endl;
           tl.sizes[0] = tl.sizes[loc_tensor_info - 1];
           for (int d = 0; d < depth; d++)
             tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1];
           loc_tensor_info = 1;
           tl.start_tensor_this_launch = t;
         }
       }
     }
   }
 }
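For orientation, multi_tensor_apply splits every tensor into fixed-size chunks, packs (tensor, chunk) descriptors into the TensorListMetadata tables, and launches one fused kernel whenever a table fills up or the work runs out. A simplified Python sketch of that batching idea; the carry-over bookkeeping of the CUDA version is omitted and the helper name is made up for illustration:

import torch

def chunked_launches(tensors, chunk_size, max_blocks=320):
    """Toy sketch of the batching above: split every tensor into fixed-size
    chunks and group chunk descriptors into launches of at most max_blocks
    blocks (the exact carry-over bookkeeping of the CUDA version is omitted)."""
    launch, launches = [], []
    for t, tensor in enumerate(tensors):
        n_chunks = (tensor.numel() + chunk_size - 1) // chunk_size
        for chunk in range(n_chunks):
            launch.append((t, chunk))      # block -> (tensor index, chunk index)
            if len(launch) == max_blocks:
                launches.append(launch)    # would be one kernel launch
                launch = []
    if launch:
        launches.append(launch)
    return launches

# Example: three small tensors, 1024-element chunks -> a single fused launch.
print(len(chunked_launches([torch.empty(5000), torch.empty(300), torch.empty(2048)], 1024)))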

View File

@ -1,4 +1,9 @@
/* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */ /* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */
/* Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
Licensed under the MIT License.
*/
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include "compat.h" #include "compat.h"

View File

@@ -1,4 +1,11 @@
 # modified from https://github.com/NVIDIA/apex/blob/master/apex/optimizers/fused_adam.py
+'''
+Copyright 2020 The Microsoft DeepSpeed Team
+Copyright NVIDIA/apex
+This file is adapted from fused adam in NVIDIA/apex, commit a109f85
+Licensed under the MIT License.
+'''
 import torch
 from colossalai.registry import OPTIMIZERS
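For orientation, a hedged usage sketch of the fused optimizer this file defines, assuming it is exported as colossalai.nn.optimizer.FusedAdam with apex-style constructor arguments and that the CUDA extension above is built; illustrative only, not code from this commit:

import torch
from colossalai.nn.optimizer import FusedAdam  # assumed export path

# FusedAdam drives the multi_tensor_adam CUDA kernels shown above,
# so both the model and its gradients must live on a CUDA device.
model = torch.nn.Linear(64, 64).cuda()
optimizer = FusedAdam(model.parameters(), lr=1e-3, weight_decay=0.0)

loss = model(torch.randn(8, 64, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()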

View File

@@ -1,3 +1,4 @@
+# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
 import math
 import warnings
 from enum import Enum

View File

@@ -1,7 +1,12 @@
+# This code has been adapted from the DeepSpeed library.
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
-import torch
 import functools
 from typing import Optional
+import torch

 def substitute_init_recursively(cls, func, visited: set):
     for subcls in cls.__subclasses__():

View File

@@ -1,3 +1,4 @@
+# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
 import functools
 import itertools
 from collections import OrderedDict

View File

@@ -1,3 +1,4 @@
+# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
 from functools import partial
 from typing import Optional

View File

@@ -1,3 +1,4 @@
+# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
 from enum import Enum
 from os import stat
 from typing import Dict, Optional, Tuple
@@ -5,20 +6,21 @@ from typing import Dict, Optional, Tuple
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.logging import get_dist_logger
-from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.gemini.tensor_utils import (colo_model_data_tensor_move_inline, colo_tensor_mem_usage)
-from colossalai.zero.sharded_model import ShardedModelV2
-from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
 from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from torch.optim import Optimizer
-from colossalai.gemini.stateful_tensor import (StatefulTensor, TensorState)
+from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
 from colossalai.gemini.tensor_placement_policy import AutoTensorPlacementPolicy
+from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
+from colossalai.logging import get_dist_logger
+from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.zero.sharded_model import ShardedModelV2
+from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32

 class OptimState(Enum):
@@ -36,9 +38,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
     `PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_

     GPU margin space is the remaining space after removing peak non-model data from the overall GPU memory,
     which is detected by a runtime memory tracer.
     We place as many OS chunks in the margin space as possible.
     The size of margin space can be controlled by ``gpu_margin_mem_ratio``.
     If it is set as ``0.0``, it is the same as classical ZeRO optimizer.
@@ -54,8 +56,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         sharded_model (ShardedModelV2): A sharded model initialized by class ShardedModelV2. The optimizer will use the
             shard strategy provided by sharded model to shard param fp32 tensors.
         optimizer (Optimizer): An Optimizer instance.
         gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
             which will be used when using hybrid CPU optimizer.
             This argument is meaningless when `tensor_placement_policy` of `ShardedModelV2` is not "auto".
             Defaults to 0.0.
         initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
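For orientation, a hedged sketch of how the documented arguments are passed, assuming the import path colossalai.zero.sharded_optim.ShardedOptimizerV2; building the ShardedModelV2 itself (shard strategy, ZeRO init context) is elided. Illustrative only, not code from this commit:

import torch
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2  # assumed module path

def build_zero_optimizer(sharded_model: ShardedModelV2,
                         base_optimizer: torch.optim.Optimizer) -> ShardedOptimizerV2:
    # gpu_margin_mem_ratio=0.0 behaves like a classical ZeRO optimizer (see the
    # docstring above); a positive ratio lets optimizer-state chunks use leftover
    # GPU memory when tensor_placement_policy of ShardedModelV2 is "auto".
    return ShardedOptimizerV2(sharded_model,
                              base_optimizer,
                              gpu_margin_mem_ratio=0.0,
                              initial_scale=2**32)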

View File

@@ -1,3 +1,7 @@
+# This code has been adapted from the DeepSpeed library.
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
 import importlib
 import os
 import time