[hotfix] fix build error when torch version >= 1.13 (#1803)

pull/1809/head^2
xcnick 2022-11-08 09:40:24 +08:00 committed by GitHub
parent f5a92c288c
commit e0da01ea71
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 1 deletions

View File

@ -2,8 +2,13 @@
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <torch/torch.h>
#if TORCH_VERSION_MINOR >= 13
#include <torch/csrc/distributed/c10d/Types.hpp>
#else
#include <c10d/Types.hpp> #include <c10d/Types.hpp>
#endif
#include <iostream> #include <iostream>
#include "context.h" #include "context.h"

View File

@ -4,8 +4,14 @@
#include <cuda.h> #include <cuda.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#include <torch/torch.h>
#if TORCH_VERSION_MINOR >= 13
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#else
#include <c10d/ProcessGroup.hpp> #include <c10d/ProcessGroup.hpp>
#endif
#include <string> #include <string>
#include <type_traits> #include <type_traits>
@ -157,4 +163,4 @@ class MultiHeadAttention {
c10::intrusive_ptr<c10d::ProcessGroup> pg; c10::intrusive_ptr<c10d::ProcessGroup> pg;
int pg_size; int pg_size;
}; };