// Adapted from turboderp exllama: https://github.com/turboderp/exllama #include "column_remap.cuh" #include "util.cuh" const int SHUF_BLOCKSIZE_X = 256; const int SHUF_BLOCKSIZE_Y = 16; __global__ void column_remap_kernel ( const half* __restrict__ x, half* __restrict__ x_new, const int x_width, const int x_height, const uint32_t* x_map ) { int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x; int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y; if (x_column >= x_width) return; //if (x_row >= x_height) return; int x_stride = x_width; int x_idx = x_row * x_stride + x_column; int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height); int x_idx_end = x_row_end * x_stride + x_column; int s_column = x_map[x_column]; int s_idx = x_row * x_stride + s_column; while (x_idx < x_idx_end) { x_new[x_idx] = x[s_idx]; x_idx += x_stride; s_idx += x_stride; } } // Remap columns in x to correspond to sequential group index before matmul // // perform x -> seq_x such that seq_x @ seq_w == x @ w void column_remap_cuda ( const half* x, half* x_new, const int x_height, const int x_width, const uint32_t* x_map ) { dim3 threads(SHUF_BLOCKSIZE_X, 1, 1); dim3 blocks ( (x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X, (x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y, 1 ); column_remap_kernel<<>>(x, x_new, x_width, x_height, x_map); }