/* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */ /* Copyright 2020 The Microsoft DeepSpeed Team Copyright NVIDIA/apex This file is adapted from fused adam in NVIDIA/apex, commit a109f85 Licensed under the MIT License. */ #pragma once #include #ifndef TORCH_CHECK #define TORCH_CHECK AT_CHECK #endif #ifdef VERSION_GE_1_3 #define DATA_PTR data_ptr #else #define DATA_PTR data #endif #define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \ switch (TYPE) { \ case at::ScalarType::Half: { \ using scalar_t = at::Half; \ __VA_ARGS__; \ break; \ } \ case at::ScalarType::BFloat16: { \ using scalar_t = at::BFloat16; \ __VA_ARGS__; \ break; \ } \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ } #define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \ switch (TYPE) { \ case at::ScalarType::Float: { \ using scalar_t = float; \ __VA_ARGS__; \ break; \ } \ case at::ScalarType::Half: { \ using scalar_t = at::Half; \ __VA_ARGS__; \ break; \ } \ case at::ScalarType::BFloat16: { \ using scalar_t = at::BFloat16; \ __VA_ARGS__; \ break; \ } \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ } #define DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(HIGH_PRECISION, \ TYPE, NAME, ...) \ switch (HIGH_PRECISION) { \ case false: { \ const bool high_precision = false; \ DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \ break; \ } \ case true: { \ const bool high_precision = true; \ DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \ break; \ } \ default: \ AT_ERROR("HIGH_PRECISION must be bool, but get ", HIGH_PRECISION, "."); \ } #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \ switch (TYPEIN) { \ case at::ScalarType::Float: { \ using scalar_t_in = float; \ switch (TYPEOUT) { \ case at::ScalarType::Float: { \ using scalar_t_out = float; \ __VA_ARGS__; \ break; \ } \ case at::ScalarType::Half: { \ using scalar_t_out = at::Half; \ __VA_ARGS__; \ break; \ } \ case at::ScalarType::BFloat16: { \ using scalar_t_out = at::BFloat16; \ __VA_ARGS__; \ break; \ } \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ } \ break; \ } \ case at::ScalarType::Half: { \ using scalar_t_in = at::Half; \ using scalar_t_out = at::Half; \ __VA_ARGS__; \ break; \ } \ case at::ScalarType::BFloat16: { \ using scalar_t_in = at::BFloat16; \ using scalar_t_out = at::BFloat16; \ __VA_ARGS__; \ break; \ } \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ } // Forward/backward compatiblity hack around // https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288 // pending more future-proof guidance from upstream. // struct TypeShim // { // const at::Type& payload; // TypeShim(const at::Type& type) : payload(type) {} // // Enable trivial conversion to a const at::Type& for pre-3aeb78 // operator const at::Type&(){ return payload; }; // // Enable dispatch switch statements to take *this directly for post-3aeb78 // //operator at::ScalarType(){ return payload.; }; // }; #define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \ switch (TYPE) { \ case at::ScalarType::Float: { \ using scalar_t_##LEVEL = float; \ __VA_ARGS__; \ break; \ } \ case at::ScalarType::Half: { \ using scalar_t_##LEVEL = at::Half; \ __VA_ARGS__; \ break; \ } \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ } #define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \ switch (TYPE) { \ case at::ScalarType::Float: { \ using scalar_t_##LEVEL = float; \ __VA_ARGS__; \ break; \ } \ case at::ScalarType::Half: { \ using scalar_t_##LEVEL = at::Half; \ __VA_ARGS__; \ break; \ } \ case at::ScalarType::Byte: { \ using scalar_t_##LEVEL = uint8_t; \ __VA_ARGS__; \ break; \ } \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ } #define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \ switch (TYPE) { \ case at::ScalarType::Double: { \ using scalar_t_##LEVEL = double; \ __VA_ARGS__; \ break; \ } \ case at::ScalarType::Float: { \ using scalar_t_##LEVEL = float; \ __VA_ARGS__; \ break; \ } \ case at::ScalarType::Half: { \ using scalar_t_##LEVEL = at::Half; \ __VA_ARGS__; \ break; \ } \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ } #define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \ switch (TYPE) { \ case at::ScalarType::Double: { \ using scalar_t_##LEVEL = double; \ __VA_ARGS__; \ break; \ } \ case at::ScalarType::Float: { \ using scalar_t_##LEVEL = float; \ __VA_ARGS__; \ break; \ } \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ } #define DISPATCH_FLOAT_AND_HALF_FOR_G_P(GTYPE, PTYPE, LEVEL, NAME, ...) \ if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Float) { \ using g_scalar_t_##LEVEL = float; \ using p_scalar_t_##LEVEL = float; \ __VA_ARGS__; \ } else if (GTYPE == at::ScalarType::Float && \ PTYPE == at::ScalarType::Half) { \ using g_scalar_t_##LEVEL = float; \ using p_scalar_t_##LEVEL = at::Half; \ __VA_ARGS__; \ } else if (GTYPE == at::ScalarType::Half && \ PTYPE == at::ScalarType::Float) { \ using g_scalar_t_##LEVEL = at::Half; \ using p_scalar_t_##LEVEL = float; \ __VA_ARGS__; \ } else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Half) { \ using g_scalar_t_##LEVEL = at::Half; \ using p_scalar_t_##LEVEL = at::Half; \ __VA_ARGS__; \ } else if (GTYPE == at::ScalarType::Float && \ PTYPE == at::ScalarType::BFloat16) { \ using g_scalar_t_##LEVEL = float; \ using p_scalar_t_##LEVEL = at::BFloat16; \ __VA_ARGS__; \ } else if (GTYPE == at::ScalarType::BFloat16 && \ PTYPE == at::ScalarType::Float) { \ using g_scalar_t_##LEVEL = at::BFloat16; \ using p_scalar_t_##LEVEL = float; \ __VA_ARGS__; \ } else if (GTYPE == at::ScalarType::BFloat16 && \ PTYPE == at::ScalarType::BFloat16) { \ using g_scalar_t_##LEVEL = at::BFloat16; \ using p_scalar_t_##LEVEL = at::BFloat16; \ __VA_ARGS__; \ } else { \ AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \ "'"); \ }