diff --git a/hopper/CMakeLists.txt b/hopper/CMakeLists.txt new file mode 100644 index 0000000..9adfb68 --- /dev/null +++ b/hopper/CMakeLists.txt @@ -0,0 +1,42 @@ +cmake_minimum_required(VERSION 3.22.1) +project(flashdecoding CUDA CXX) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CUDA_STANDARD 17) +set(CMAKE_CUDA_ARCHITECTURES 90) + +set(INCLUDE_DIR /home/zhichen/dayou/BitAttn/3rdparty/cutlass/include) + +# Enable ccache if available +find_program(CCACHE_PROGRAM ccache) +if(CCACHE_PROGRAM) + set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}") + set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}") +endif() + + +find_package(Torch REQUIRED) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") + +message(STATUS "Compile testing packdecode kernel.") +add_executable(test_single_decode + ${PROJECT_SOURCE_DIR}/src/test_single_decode.cu + ${PROJECT_SOURCE_DIR}/src/genfile/flash_fwd_hdim128_fp16_split_sm90.cu + ${PROJECT_SOURCE_DIR}/src/genfile/flash_qpack_hdim128_fp16_sm80_4bit.cu + ${PROJECT_SOURCE_DIR}/src/genfile/flash_fwd_combine.cu +) +target_link_libraries(test_single_decode "${TORCH_LIBRARIES}") +target_include_directories(test_single_decode PRIVATE ${INCLUDE_DIR}) +target_compile_options(test_single_decode PRIVATE $<$:-maxrregcount=255 -gencode arch=compute_90a,code=sm_90a -w>) + +message(STATUS "Compile benchmark single decode kernel.") +add_executable(bench_single_decode + ${PROJECT_SOURCE_DIR}/src/bench_single_decode.cu + ${PROJECT_SOURCE_DIR}/src/genfile/flash_fwd_hdim128_fp16_split_sm90.cu + ${PROJECT_SOURCE_DIR}/src/genfile/flash_qpack_hdim128_fp16_sm80_4bit.cu + ${PROJECT_SOURCE_DIR}/src/genfile/flash_fwd_combine.cu +) +target_link_libraries(bench_single_decode "${TORCH_LIBRARIES}") +target_include_directories(bench_single_decode PRIVATE ${INCLUDE_DIR}) +target_compile_options(bench_single_decode PRIVATE $<$:-maxrregcount=255 -gencode arch=compute_90a,code=sm_90a -w -O3 -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED -DNDEBUG>) diff --git a/hopper/src/bench_single_decode.cu b/hopper/src/bench_single_decode.cu new file mode 100644 index 0000000..70435f0 --- /dev/null +++ b/hopper/src/bench_single_decode.cu @@ -0,0 +1,220 @@ +#include +#include + +#include "flash_api.h" + +torch::Tensor single_mha(torch::Tensor& q, torch::Tensor& k, torch::Tensor& v, int head_dim) { + const float sm_scale = 1.f / std::sqrt(float(head_dim)); + auto scaled_q = q * sm_scale; + + auto scores = torch::einsum("bthd,bshd->bhts", {scaled_q, k}); + auto attention = torch::softmax(scores, -1).to(v.dtype()); + auto output = torch::einsum("bhts,bshd->bthd", {attention, v}); + return output; +} + +template +double TestDecodingKernelPerformance(int seqlen_kv, int bs, const std::string quant_mode, const int group_size, const int repeat, const int num_splits=0) { + const int seqlen_q = 4; + const int pack_nums = 16 / num_bits; + + torch::Tensor Q_host = torch::rand({bs, seqlen_q, num_heads, head_dim}, torch::dtype(torch::kHalf)); + torch::Tensor K_host = torch::ones({bs, seqlen_kv, num_heads_kv, head_dim}, torch::dtype(torch::kHalf)); + torch::Tensor V_host = torch::ones({bs, seqlen_kv, num_heads_kv, head_dim}, torch::dtype(torch::kHalf)); + + torch::Tensor Q_device = Q_host.to(torch::kCUDA); + // torch::Tensor K_device = K_host.to(torch::kCUDA); + // torch::Tensor V_device = V_host.to(torch::kCUDA); + + at::Tensor k_pack, k_params, v_pack, v_params; + if (quant_mode == "k-channel") { + k_pack = torch::empty({bs, seqlen_kv / pack_nums, num_heads_kv, head_dim}, 
torch::dtype(torch::kUInt16)).to(torch::kCUDA); + k_params = torch::empty({bs, seqlen_kv / group_size, num_heads_kv, head_dim}, torch::dtype(torch::kFloat32)).to(torch::kCUDA); + } else { + k_pack = torch::empty({bs, seqlen_kv, num_heads_kv, head_dim / pack_nums}, torch::dtype(torch::kUInt16)).to(torch::kCUDA); + k_params = torch::empty({bs, head_dim / group_size, num_heads_kv, seqlen_kv}, torch::dtype(torch::kFloat32)).to(torch::kCUDA); + } + v_pack = torch::empty({bs, seqlen_kv, num_heads_kv, head_dim / pack_nums}, torch::dtype(torch::kUInt16)).to(torch::kCUDA); + v_params = torch::empty({bs, head_dim / group_size, num_heads_kv, seqlen_kv}, torch::dtype(torch::kFloat32)).to(torch::kCUDA); + + // Convert K, V to unpadded format + // torch::Tensor K_unpad = K_device.reshape({bs * seqlen_kv, num_heads_kv, head_dim}); + // torch::Tensor V_unpad = V_device.reshape({bs * seqlen_kv, num_heads_kv, head_dim}); + + // auto cu_seqlens_k = torch::arange(0, (bs + 1) * seqlen_kv, seqlen_kv, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); + // std::optional opt_block_table = std::nullopt; + + // kvcache_qpack( + // K_unpad, k_pack, k_params, + // V_unpad, v_pack, v_params, + // opt_block_table, + // cu_seqlens_k, + // seqlen_kv, + // quant_mode, + // group_size + // ); + + const float sm_scale = 1 / std::sqrt(float(head_dim)); + + // Warm up + for (int i = 1; i < 5; ++i) + mha_fwd_kvcache(Q_device, + K_host, k_pack, k_params, + V_host, v_pack, v_params, + sm_scale); + + // Benchmark + cudaEvent_t start, end; + cudaEventCreate(&start); + cudaEventCreate(&end); + cudaEventRecord(start); + for (int i = 0; i < repeat; i++) { + // sm_scale = 1 / std::sqrt(float(head_dim)); + mha_fwd_kvcache(Q_device, + K_host, k_pack, k_params, + V_host, v_pack, v_params, + sm_scale, + num_splits); + } + cudaEventRecord(end); + cudaEventSynchronize(end); + + float msec, sec; + cudaEventElapsedTime(&msec, start, end); + msec = msec / repeat; + + return msec; +} + +int main() { + const int num_heads = 128; + const int num_heads_kv = 32; + const int head_dim = 128; + const int num_bits = 4; + const std::string quant_mode = "k-tensor"; + const int group_size = 128; + const int test_num = 10; + + int len_list[test_num]; + len_list[0] = 1024; + for (int i = 1; i < test_num; i++) { + len_list[i] = len_list[i - 1] * 2; + } + + int bs_list[7]; + bs_list[0] = 2; + for (int i = 1; i < 7; i++) { + bs_list[i] = bs_list[i - 1] * 2; + } + + const int outer_repeat = 1, inner_repeat = 1; + + // printf("\n######## Benchmark single decode ########\n"); + // for (int j = 0; j < test_num; j++) { + + // int seqlen_kv = len_list[j]; + // double max_msec = 0.0; + // double min_msec = DBL_MAX; + // double total_msec = 0.0; + + // for (int k = 0; k < outer_repeat; k++) { + // double this_sec = TestDecodingKernelPerformance(seqlen_kv, quant_mode, group_size, inner_repeat); + // max_msec = max(max_msec, this_sec); + // min_msec = min(min_msec, this_sec); + // total_msec += this_sec; + // } + + // double avg_msec = total_msec / outer_repeat; + // printf("seqlen_kv num_heads head_dim = %6d %6d %6d, ", seqlen_kv, num_heads, head_dim); + // printf("Time = %12.8lf %12.8lf %12.8lf ms, \n", min_msec, avg_msec, max_msec); + // } + + printf("\n######## Benchmark single decode with different num_splits ########\n"); + for (int j = 0; j < test_num; j++) { + int bs = 1; + int seqlen_kv = len_list[j]; + double best_time = DBL_MAX; + int best_splits = 2; + + printf("\nTesting seqlen_kv=%d:\n", seqlen_kv); + printf("num_splits min_time(ms) 
avg_time(ms) max_time(ms)\n"); + printf("------------------------------------------------\n"); + + // Test different num_splits values + for (int splits = 0; splits <= 20; splits++) { + double max_msec = 0.0; + double min_msec = DBL_MAX; + double total_msec = 0.0; + + for (int k = 0; k < outer_repeat; k++) { + double this_sec = TestDecodingKernelPerformance( + seqlen_kv, bs, quant_mode, group_size, inner_repeat, splits); + max_msec = max(max_msec, this_sec); + min_msec = min(min_msec, this_sec); + total_msec += this_sec; + } + + double avg_msec = total_msec / outer_repeat; + printf("%9d %11.4f %11.4f %11.4f\n", splits, min_msec, avg_msec, max_msec); + + if (min_msec < best_time) { + best_time = min_msec; + best_splits = splits; + } + if (j < 2) { + break; + } else if (j < 5 && splits > 5) { + break; + } + } + + printf("\nBest result for seqlen_kv=%d: num_splits=%d, time=%.4f ms\n", + seqlen_kv, best_splits, best_time); + } + + // printf("\n######## Benchmark single decode with different num_splits ########\n"); + // for (int j = 0; j < 7; j++) { + // int bs = bs_list[j]; + // int seqlen_kv = 32768; + // double best_time = DBL_MAX; + // int best_splits = 0; + + // printf("\nTesting batch_size=%d:\n", bs); + // printf("num_splits min_time(ms) avg_time(ms) max_time(ms)\n"); + // printf("------------------------------------------------\n"); + + // // Test different num_splits values + // for (int splits = 0; splits <= 10; splits++) { + // double max_msec = 0.0; + // double min_msec = DBL_MAX; + // double total_msec = 0.0; + + // for (int k = 0; k < outer_repeat; k++) { + // double this_sec = TestDecodingKernelPerformance( + // seqlen_kv, bs, quant_mode, group_size, inner_repeat, splits); + // max_msec = max(max_msec, this_sec); + // min_msec = min(min_msec, this_sec); + // total_msec += this_sec; + // } + + // double avg_msec = total_msec / outer_repeat; + // printf("%9d %11.4f %11.4f %11.4f\n", splits, min_msec, avg_msec, max_msec); + + // if (min_msec < best_time) { + // best_time = min_msec; + // best_splits = splits; + // } + + // if (j < 2) { + // break; + // } else if (j < 5 && splits > 5) { + // break; + // } + // } + + // printf("\nBest result for seqlen_kv=%d: num_splits=%d, time=%.4f ms\n", + // seqlen_kv, best_splits, best_time); + // } + + return 0; +} \ No newline at end of file diff --git a/hopper/src/flash_api.h b/hopper/src/flash_api.h new file mode 100644 index 0000000..224e2b2 --- /dev/null +++ b/hopper/src/flash_api.h @@ -0,0 +1,568 @@ +#pragma once + +#include +#include +#include + +#include + +#include "include/flash.h" +#include "include/heuristics.h" +#include "include/tile_size.h" + +#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) 
TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +inline int round_up_headdim(int head_size) { + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (head_size <= 64) { return 64; } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (head_size <= 96) { return 96; } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if (head_size <= 128) { return 128; } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (head_size <= 192) { return 192; } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (head_size <= 256) { return 256; } + #endif + return 256; +} + +void set_params_fprop(Flash_fwd_params ¶ms, + // sizes + const size_t b, + const size_t seqlen_q, + const size_t seqlen_k, const size_t seqlen_k_pack, const size_t seqlen_k_params, + const size_t seqlen_q_rounded, + const size_t seqlen_k_rounded, + const size_t h, + const size_t h_k, + const size_t d, const size_t d_kpack, const size_t d_vpack, + const size_t d_rounded, + // device pointers + const at::Tensor q, + const at::Tensor k, const at::Tensor k_pack, const at::Tensor k_params, + const at::Tensor v, const at::Tensor v_pack, const at::Tensor v_params, + at::Tensor out, + void *cu_seqlens_q_d, + void *cu_seqlens_k_d, + void *seqused_q, + void *seqused_k, + void *softmax_lse_d, + float p_dropout, + float softmax_scale, + int window_size_left, + int window_size_right, + const float softcap=0.f, + const int sm_margin=0) { + + // Reset the parameters + params = {}; + + params.is_bf16 = q.dtype() == torch::kBFloat16; + params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn; + + // Set the pointers and strides. + params.q_ptr = q.data_ptr(); + params.k_ptr = k.data_ptr(); + params.K_pack_ptr = k_pack.data_ptr(); + params.k_params_ptr = k_params.data_ptr(); + params.v_ptr = v.data_ptr(); + params.v_pack_ptr = v_pack.data_ptr(); + params.v_params_ptr = v_params.data_ptr(); + + // All stride are in elements, not bytes. 
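+    // Layout assumed here (matching the tensors allocated in bench_single_decode.cu, kept as a
+    // note rather than a guarantee): q/k/v are (batch, seqlen, num_heads, head_dim), so
+    // stride(-3) steps along the sequence and stride(-2) across heads. For the per-tensor
+    // ("k-tensor") path the k_params/v_params tensors are
+    // (batch, head_dim / group_size, num_heads_kv, seqlen), so their "row" stride is the
+    // per-token stride(-1) and the "dim" stride is stride(-3).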
+ params.q_row_stride = q.stride(-3); + params.k_row_stride = k.stride(-3); + params.K_pack_row_stride = k_pack.stride(-3); + params.k_params_row_stride = k_params.stride(-1); + params.v_row_stride = v.stride(-3); + params.v_pack_row_stride = v_pack.stride(-3); + params.v_params_row_stride = v_params.stride(-1); + + params.q_head_stride = q.stride(-2); + params.k_head_stride = k.stride(-2); + params.K_pack_head_stride = k_pack.stride(-2); + params.k_params_head_stride = k_params.stride(-2); + params.v_head_stride = v.stride(-2); + params.v_pack_head_stride = v_pack.stride(-2); + params.v_params_head_stride = v_params.stride(-2); + + params.v_dim_stride = v.stride(-1); + params.k_params_dim_stride = k_params.stride(-3); + params.v_params_dim_stride = v_params.stride(-3); + + params.o_ptr = out.data_ptr(); + params.o_row_stride = out.stride(-3); + params.o_head_stride = out.stride(-2); + + if (cu_seqlens_q_d == nullptr) { + params.q_batch_stride = q.stride(0); + params.o_batch_stride = out.stride(0); + } + if (cu_seqlens_k_d == nullptr) { + params.k_batch_stride = k.stride(0); + params.K_pack_batch_stride = k_pack.stride(0); + params.k_params_batch_stride = k_params.stride(0); + params.v_batch_stride = v.stride(0); + params.v_pack_batch_stride = v_pack.stride(0); + params.v_params_batch_stride = v_params.stride(0); + } + + params.cu_seqlens_q = static_cast(cu_seqlens_q_d); + params.cu_seqlens_k = static_cast(cu_seqlens_k_d); + params.seqused_q = static_cast(seqused_q); + params.seqused_k = static_cast(seqused_k); + + // Softmax sum + params.softmax_lse_ptr = softmax_lse_d; + + // Set the dimensions. + params.b = b; + params.h = h; + params.h_k = h_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.seqlen_k_pack = seqlen_k_pack; + params.seqlen_k_params = seqlen_k_params; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = d; + params.d_kpack = d_kpack; + params.d_vpack = d_vpack; + params.d_vparams = v_params.size(1); + params.d_rounded = d_rounded; + + // Set the different scale values. + params.scale_softmax = softmax_scale; + params.softcap = softcap; + + // Set this to probability of keeping an element to simplify things. + params.p_dropout = 1.f - p_dropout; + // Convert p from float to int so we don't have to convert the random uint to float to compare. + // [Minor] We want to round down since when we do the comparison we use <= instead of < + // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0)); + // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0)); + params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); + params.rp_dropout = 1.f / params.p_dropout; + TORCH_CHECK(p_dropout < 1.f); + #ifdef FLASHATTENTION_DISABLE_DROPOUT + TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout."); + #endif + + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
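+    // Example of this convention (sketch, not exercised by mha_fwd_kvcache, which passes -1/-1):
+    // (-1, 0) is plain causal, (127, 0) is a sliding window of 128 tokens ending at the
+    // diagonal, and (-1, -1) disables masking entirely.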
+ params.is_causal = window_size_left < 0 && window_size_right == 0; + params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal; + + // TODO: check this + if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k - 1; } + if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_q - 1; } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + + params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor; + params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin; + + #ifdef FLASHATTENTION_DISABLE_LOCAL + TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention."); + #endif +} + +inline int get_num_splits(Flash_fwd_params const& params) { + #ifdef FLASHATTENTION_DISABLE_SPLIT + return 1; + #else + // Always enable PackGQA for Split + // params.page_table must already be set + // This needs to match the kernel configs + bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k; + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table, params.softcap > 0.f); + // Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits + // has not been set here. It's OK though because we might just underestimate kBlockN a bit + auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr); + int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); + int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x); + int seqlen_q_packgqa = params.seqlen_q * (params.h / params.h_k); + // If is_local, we're not going to load all of seqlen_k + int const seqlen_k_loaded = !params.is_local + ? params.seqlen_k + : std::max(0, std::min(params.seqlen_k, params.window_size_right + params.window_size_left + 1 + kBlockM)); + int const num_n_blocks = (seqlen_k_loaded + kBlockN - 1) / kBlockN; + int const num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM; + return num_splits_heuristic(params.b * (!params.pack_gqa ? params.h : params.h_k) * num_m_blocks, params.num_sm, num_n_blocks, 128); + // return num_splits_heuristic(params.b * params.h_k * num_m_blocks, params.b * params.h_k, + // params.num_sm, num_n_blocks, 128, params.d_rounded); + #endif +} + +inline bool get_pack_gqa(Flash_fwd_params const& params) { + // Always enable PackGQA for Sm8x or PagedKV or Split to reduce compilation and binary size. + // Has little effect on speed. + if (params.arch < 90 || params.page_table || params.num_splits > 1) { return true; } + #ifdef FLASHATTENTION_DISABLE_PACKGQA + return false; + #else + // params.page_table must already be set + if (params.h == params.h_k) { return false; } + // This needs to match the kernel configs + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 
1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table, params.softcap > 0.f); + int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90); + return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM); + #endif +} + + +void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + return run_mha_fwd_<90, cutlass::half_t, 128, true, false, false, true>(params, stream); +} + +void run_mha_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { + return run_mha_fwd_combine_(params, stream); +} + +template +void run_kvcache_qpack(Flash_fwd_params ¶ms, cudaStream_t stream) { + if (params.quant_mode == "k-channel") { + if (params.group_size == 32) { + // run_kvcache_qpack_(params, stream); + } else if (params.group_size == 64) { + // run_kvcache_qpack_(params, stream); + } else if (params.group_size == 128) { + // run_kvcache_qpack_(params, stream); + } + } else { + if (params.group_size == 32) { + // run_kvcache_qpack_(params, stream); + } else if (params.group_size == 64) { + // run_kvcache_qpack_(params, stream); + } else if (params.group_size == 128) { + run_kvcache_qpack_(params, stream); + } + } +} + +template +at::Tensor +mha_fwd_kvcache(at::Tensor &q, + const at::Tensor &k, const at::Tensor &k_pack, const at::Tensor &k_params, + const at::Tensor &v, const at::Tensor &v_pack, const at::Tensor &v_params, + const float softmax_scale, + const int num_splits=0) { + + auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm8x = dprops->major >= 8; + TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); + + auto q_type = q.scalar_type(); + TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16 || q_type == at::ScalarType::Float8_e4m3fn, + "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type"); + if (dprops->major < 9) { + TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, + "FlashAttention on Ampere/Ada cards only supports fp16 and bf16 data type"); + } + TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype"); + TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype"); + + CHECK_DEVICE(q); // CHECK_DEVICE(k); CHECK_DEVICE(v); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + at::Tensor page_table; + const bool paged_KV = false; + at::Tensor cu_seqlens_q; + bool const is_varlen_q = false; + at::Tensor cu_seqlens_k; + bool const is_varlen_k = false; + bool const is_varlen = false; + + auto const sizes = q.sizes(); + const int batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1; + int seqlen_q = sizes[1]; + int total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0]; + int num_heads = q.size(-2); + int const head_size = q.size(-1); + int const head_size_kpack = k_pack.size(-1); + int const head_size_vpack = v_pack.size(-1); + int const max_num_pages_per_seq = !paged_KV ? 0 : page_table.size(1); + int const num_pages = !paged_KV ? 0 : k.size(0); + int const page_size = !paged_KV ? 1 : k.size(1); + int const seqlen_k = !paged_KV ? k.size(1) : max_num_pages_per_seq * page_size; + int const seqlen_k_pack = k_pack.size(1); + int const seqlen_k_params = k_params.size(-1); + int const total_k = !is_varlen_k ? 
batch_size * k.size(1) : k.size(0); + int const num_heads_k = k.size(-2); + int const batch_size_k = k.size(0); + + int window_size_left = -1; + int window_size_right = -1; + + bool is_causal = window_size_left < 0 && window_size_right == 0; + + int const alignment = q_type == torch::kFloat8_e4m3fn ? 16 : 8; + + auto opts = q.options(); + at::Tensor out; + out = torch::empty_like(q); + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + int const head_size_rounded = round_up_headdim(head_size); + int const seqlen_q_rounded = round_multiple(seqlen_q, 128); + int const seqlen_k_rounded = round_multiple(seqlen_k, 128); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + at::Tensor softmax_lse; + if (!is_varlen_q) { + softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + } else { + softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); + } + + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + seqlen_q, seqlen_k, seqlen_k_pack, seqlen_k_params, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_kpack, head_size_vpack, + head_size_rounded, + q, + k, k_pack, k_params, + v, v_pack, v_params, + out, + !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(), + !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(), + nullptr, + nullptr, + softmax_lse.data_ptr(), + /*p_dropout=*/0.f, + softmax_scale, + window_size_left, + window_size_right); + params.total_q = total_q; + params.total_k = total_k; + params.sink_token_length = 0; + params.b_k = batch_size_k; + + params.page_size = page_size; + params.num_pages = num_pages; + + params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; + params.pack_gqa = get_pack_gqa(params); + if (params.num_splits == 1) { + params.num_splits = 2; + } + // printf("num_splits: %d\n", params.num_splits); + // printf("pack_gqa: %d\n", params.pack_gqa); + + params.rotary_dim = 0; + at::Tensor out_accum, softmax_lse_accum; + auto outaccum_type = at::ScalarType::Float; + if (params.num_splits > 1) { + TORCH_CHECK(params.num_splits <= 256, "num_splits > 256 not supported"); + if (!is_varlen_q) { + out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size}, opts.dtype(outaccum_type)); + softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + params.oaccum_batch_stride = out_accum.stride(1); + params.lseaccum_batch_stride = softmax_lse_accum.stride(1); + } else { + out_accum = torch::empty({params.num_splits, num_heads, total_q, head_size}, opts.dtype(outaccum_type)); + softmax_lse_accum = torch::empty({params.num_splits, num_heads, total_q}, opts.dtype(at::kFloat)); + } + params.is_fp32 = false; + params.oaccum_ptr = out_accum.data_ptr(); + params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); + params.oaccum_split_stride = out_accum.stride(0); + params.oaccum_row_stride = out_accum.stride(-2); + params.oaccum_head_stride = out_accum.stride(-3); + params.lseaccum_split_stride = softmax_lse_accum.stride(0); + params.lseaccum_head_stride = softmax_lse_accum.stride(-2); + } + + at::Tensor tile_count_semaphore; + // We don't use the persistent scheduler if Split and not Varlen + bool const persistent_scheduler = params.arch >= 90 + ? 
(((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen) + : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1)); + if (persistent_scheduler) { + tile_count_semaphore = torch::zeros({1}, opts.dtype(torch::kInt32)); + params.tile_count_semaphore = tile_count_semaphore.data_ptr(); + } else { + params.tile_count_semaphore = nullptr; + } + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + run_mha_fwd(params, stream); + if (params.num_splits > 1) { + if (is_varlen_q) { + params.b = 1; + params.seqlen_q = total_q; + } + run_mha_fwd_combine(params, stream); + } + + return out; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// QPacking +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +void set_params_fprop_qpack(Flash_fwd_params ¶ms, + // sizes + const size_t b, + const size_t seqlen_k, + const size_t h, const size_t h_k, + const size_t d, + // device pointers + const at::Tensor k, at::Tensor k_pack, at::Tensor k_params, + const at::Tensor v, at::Tensor v_pack, at::Tensor v_params, + void *cu_seqlens_k_d, + const std::string quant_mode, + const int group_size + ) { + + // Reset the parameters + params = {}; + + params.is_bf16 = k.dtype() == torch::kBFloat16; + + // Set the pointers and strides. + params.k_ptr = k.data_ptr(); + params.K_pack_ptr = k_pack.data_ptr(); + params.k_params_ptr = k_params.data_ptr(); + params.v_ptr = v.data_ptr(); + params.v_pack_ptr = v_pack.data_ptr(); + params.v_params_ptr = v_params.data_ptr(); + // All stride are in elements, not bytes. + params.k_row_stride = k.stride(-3); + params.K_pack_row_stride = k_pack.stride(-3); + params.k_params_row_stride = k_params.stride(-1); + params.v_row_stride = v.stride(-3); + params.v_pack_row_stride = v_pack.stride(-3); + params.v_params_row_stride = v_params.stride(-1); + + params.k_params_dim_stride = k_params.stride(-3); + params.v_params_dim_stride = v_params.stride(-3); + + params.k_head_stride = k.stride(-2); + params.K_pack_head_stride = k_pack.stride(-2); + params.k_params_head_stride = k_params.stride(-2); + params.v_head_stride = v.stride(-2); + params.v_pack_head_stride = v_pack.stride(-2); + params.v_params_head_stride = v_params.stride(-2); + + // params.k_batch_stride = k.stride(0); + params.k_batch_stride = seqlen_k * k.size(-2) * k.size(-1); + params.K_pack_batch_stride = k_pack.stride(0); + params.k_params_batch_stride = k_params.stride(0); + // params.v_batch_stride = v.stride(0); + params.v_batch_stride = seqlen_k * v.size(-2) * v.size(-1); + params.v_pack_batch_stride = v_pack.stride(0); + params.v_params_batch_stride = v_params.stride(0); + + params.cu_seqlens_k = static_cast(cu_seqlens_k_d); + + // Set the dimensions. 
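+    // h_h_k_ratio is the GQA factor (query heads per KV head). quant_mode selects per-channel
+    // ("k-channel") vs. per-tensor ("k-tensor") quantization of K, and group_size is, as far as
+    // the packing layout suggests, the number of elements sharing one scale/zero pair
+    // (32/64/128 are the sizes dispatched in run_kvcache_qpack).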
+ params.b = b; + params.h = h; + params.h_k = h_k; + params.h_h_k_ratio = h / h_k; + params.seqlen_k = seqlen_k; + params.d = d; + + params.quant_mode = quant_mode; + params.group_size = group_size; +} + +template +void kvcache_qpack(const at::Tensor &k, + at::Tensor &k_pack, + at::Tensor &k_params, + const at::Tensor &v, + at::Tensor &v_pack, + at::Tensor &v_params, + c10::optional &block_table_, + const at::Tensor &cu_seqlens_k, + const int max_seqlen_k, + const std::string quant_mode, + const int group_size + ) { + + auto k_dtype = k.dtype(); + TORCH_CHECK(k_dtype == torch::kFloat16 || k_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + + TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); + + CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(cu_seqlens_k); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + CHECK_CONTIGUOUS(cu_seqlens_k); + + at::Tensor block_table; + const bool paged_KV = block_table_.has_value(); + if (paged_KV) { + block_table = block_table_.value(); + CHECK_DEVICE(block_table); + TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); + TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); + } + + const auto sizes = k.sizes(); + + const int batch_size = cu_seqlens_k.numel() - 1; + int num_heads = paged_KV ? sizes[2] : sizes[1]; + const int head_size = paged_KV ? sizes[3] : sizes[2]; + const int num_heads_k = paged_KV ? k.size(2) : k.size(1); + + const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); + const int num_blocks = !paged_KV ? 0 : k.size(0); + const int page_block_size = !paged_KV ? 1 : k.size(1); + const int page_block_size_pack = !paged_KV ? 0 : k_pack.size(1); + const int seqlen_k = !paged_KV ? k.size(1) : max_num_blocks_per_seq * page_block_size; + const int batch_size_c = !paged_KV ? 
k.size(0) : batch_size; + + TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256"); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256"); + TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)k.get_device()}; + + Flash_fwd_params params; + set_params_fprop_qpack(params, + batch_size, + max_seqlen_k, + num_heads, num_heads_k, + head_size, + k, k_pack, k_params, + v, v_pack, v_params, + /*cu_seqlens_k_d=*/nullptr, + quant_mode, + group_size + ); + + + if (max_seqlen_k > 0) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + run_kvcache_qpack(params, stream); + } + + return; +} \ No newline at end of file diff --git a/hopper/src/flash_fwd_combine_launch_template.h b/hopper/src/flash_fwd_combine_launch_template.h new file mode 100644 index 0000000..6fb0d16 --- /dev/null +++ b/hopper/src/flash_fwd_combine_launch_template.h @@ -0,0 +1,77 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + + +#pragma once + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/arch/arch.h" // For cutlass::arch::Sm80 +#include "cutlass/device_kernel.h" // For device_kernel + +#include "include/static_switch.h" +#include "include/flash.h" +#include "include/flash_fwd_combine_kernel.h" + +using namespace cute; + + +template +void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { + using TileShape_MK = cute::Shape, Int>; + using CombineKernel = flash::FlashAttnFwdCombine; + + typename CombineKernel::Arguments args { + static_cast(params.oaccum_ptr), + {!Varlen ? params.seqlen_q : params.total_q, params.d, params.num_splits, params.h, !Varlen ? params.b : 1}, // shape_O_partial + {params.oaccum_row_stride, _1{}, params.oaccum_split_stride, params.oaccum_head_stride, !Varlen ? params.oaccum_batch_stride : 0}, // stride_O_partial + static_cast(params.softmax_lseaccum_ptr), + {!Varlen ? params.seqlen_q : params.total_q, params.num_splits, params.h, !Varlen ? params.b : 1}, // shape_LSE_partial + {_1{}, params.lseaccum_split_stride, params.lseaccum_head_stride, !Varlen ? params.lseaccum_batch_stride : 0}, // stride_LSE_partial + static_cast(params.o_ptr), + {params.o_row_stride, _1{}, params.o_head_stride, !Varlen ? params.o_batch_stride : 0}, // stride_O + static_cast(params.softmax_lse_ptr), + {_1{}, !Varlen ? params.seqlen_q : params.total_q, !Varlen ? params.h * params.seqlen_q : 0}, // stride_LSE + params.cu_seqlens_q, params.seqused_q + }; + + typename CombineKernel::Params kernel_params = CombineKernel::to_underlying_arguments(args); + int num_blocks_m = cute::ceil_div(params.seqlen_q * params.h * (!Varlen ? params.b : 1), kBlockM); + dim3 grid_m(num_blocks_m, !Varlen ? 
1 : params.b); + auto kernel = cutlass::device_kernel; + int smem_size = CombineKernel::SharedStorageSize; + if (smem_size >= 48 * 1024) { + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + kernel<<>>(kernel_params); + CHECK_CUDA_KERNEL_LAUNCH(); + +} + +template +void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream) { + static_assert(kHeadDim % 32 == 0, "kHeadDim must be a multiple of 32"); + static constexpr int kBlockM = kHeadDim % 128 == 0 ? 8 : (kHeadDim % 64 == 0 ? 16 : 32); + + // BOOL_SWITCH(params.seqused_q != nullptr, Varlen, [&] { + static constexpr bool Varlen = false; + if constexpr (kBlockM >= 16) { // If kBlockM == 8 then the minimum number of splits is 32. + if (params.num_splits <= 16) { + run_flash_fwd_combine(params, stream); + return; + } + } + if (params.num_splits <= 32) { + run_flash_fwd_combine(params, stream); + } else if (params.num_splits <= 64) { + run_flash_fwd_combine(params, stream); + } else if (params.num_splits <= 128) { + run_flash_fwd_combine(params, stream); + } else { + run_flash_fwd_combine(params, stream); + } + // }); +} \ No newline at end of file diff --git a/hopper/src/flash_fwd_launch_template.h b/hopper/src/flash_fwd_launch_template.h new file mode 100644 index 0000000..f720a1e --- /dev/null +++ b/hopper/src/flash_fwd_launch_template.h @@ -0,0 +1,274 @@ +#pragma once + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" // For device_kernel +#include +#include "cutlass/cluster_launch.hpp" + +#include "include/flash.h" +#include "include/static_switch.h" +#include "include/tile_size.h" +#include "include/flash_fwd_kernel_sm90.h" +#include "include/mainloop_fwd_sm90_tma_gmma_ws.hpp" +#include "include/epilogue_fwd.hpp" +#include "include/tile_scheduler.hpp" +#include "include/kernel_traits.h" +#include "include/flash_qpack_kernel.h" + +#include + +using namespace cute; + +template +void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time"); + static_assert(!(AppendKV && V_colmajor), "AppendKV and V_colmajor cannot be enabled at the same time"); + static_assert(!(AppendKV && !Varlen), "AppendKV requires Varlen"); + static constexpr bool FP8_TransposeV = false; + + using ArchTag = std::conditional_t= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>; + + static constexpr std::tuple kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKV, Has_softcap); + static constexpr std::tuple kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKV, Varlen && Split, Has_softcap, AppendKV); + static constexpr int kBlockM = Arch >= 90 ? std::get<0>(kBlockMN_RS_IntraWGOverlap) : std::get<0>(kBlockMN_kNWarps_Stages_RS); + static constexpr int kBlockN = Arch >= 90 ? std::get<1>(kBlockMN_RS_IntraWGOverlap) : std::get<1>(kBlockMN_kNWarps_Stages_RS); + static constexpr bool Mma1_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap); + // static constexpr bool IntraWGOverlap = std::get<3>(kBlockMN_RS_IntraWGOverlap); + static constexpr bool IntraWGOverlap = false; + static constexpr int kNWarps = std::get<2>(kBlockMN_kNWarps_Stages_RS); + // static constexpr int kStages = Arch >= 90 ? 
2 : std::get<3>(kBlockMN_kNWarps_Stages_RS); + static constexpr int kStages = 1; + static constexpr bool Q_in_regs = Arch >= 90 ? false : std::get<4>(kBlockMN_kNWarps_Stages_RS); + + using TileShape_MNK = cute::Shape, Int, Int>; + using ClusterShape = cute::Shape, _1, _1>; + using CollectiveMainloop = flash::CollectiveMainloopFwdSm90; + using CollectiveEpilogue = flash::CollectiveEpilogueFwd; + + static constexpr int NumProducerThreads = Arch >= 90 ? CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads; + using SchedulerPersistent = flash::StaticPersistentTileScheduler; + using SchedulerSingleTile = flash::SingleTileScheduler; + using Scheduler = SchedulerSingleTile; + using AttnKernel = flash::enable_sm90_or_later>; + using ElementKVPack = typename CollectiveMainloop::ElementKVPack; + + bool const is_varlen_q = params.cu_seqlens_q; + bool const is_varlen_k = params.cu_seqlens_k; + bool const is_varlen_k_new = params.cu_seqlens_knew; + int seqlen_q = !is_varlen_q ? params.seqlen_q : params.total_q; + int batch_q = !is_varlen_q ? params.b : 1; + int batch_k = !is_varlen_k ? (params.kv_batch_idx ? params.b_k : params.b) : 1; + int batch_k_pack = !is_varlen_k ? (params.kv_batch_idx ? params.b_k : params.b) : 1; + + typename CollectiveMainloop::Arguments mainloop_args { + static_cast(params.q_ptr), + {seqlen_q, params.d, params.h, batch_q}, // shape_Q + {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0}, // stride_Q + + // static_cast(params.k_ptr), // K_ptr + {!PagedKV ? (!is_varlen_k ? params.seqlen_k : params.total_k) : params.page_size, + params.d, params.h_k, !PagedKV ? batch_k : params.num_pages}, // shape_K + // {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0}, // stride_K + static_cast(params.K_pack_ptr), // K_pack_ptr + {params.seqlen_k_pack, params.d_kpack, params.h_k, batch_k}, // shape_K_pack + {params.K_pack_row_stride, _1{}, params.K_pack_head_stride, params.K_pack_batch_stride}, // stride_K_pack + static_cast<__half2*>(params.k_params_ptr), // K_params_ptr + {params.seqlen_k_params, params.d, params.h_k, batch_k}, // shape_K_params + // {params.k_params_row_stride, _1{}, params.k_params_head_stride, params.k_params_batch_stride}, // stride_K_params + {_1{}, params.k_params_dim_stride, params.k_params_head_stride, params.k_params_batch_stride}, + + // static_cast(params.v_ptr), // V_ptr + // v_strides, // stride_V + static_cast(params.v_pack_ptr), // V_pack_ptr + {params.seqlen_k, params.d_vpack, params.h_k, batch_k}, // shape_V_pack + {params.v_pack_row_stride, _1{}, params.v_pack_head_stride, params.v_pack_batch_stride}, // stride_V_pack + static_cast<__half2*>(params.v_params_ptr), // V_params_ptr + {params.seqlen_k, params.d_vparams, params.h_k, batch_k}, // shape_V_params + {_1{}, params.v_params_dim_stride, params.v_params_head_stride, params.v_params_batch_stride}, // stride_V_params + + static_cast(params.knew_ptr), + {!is_varlen_k_new ? params.seqlen_knew : params.total_knew, params.d, params.h_k, !is_varlen_k_new ? params.b : 1}, // shape_K_new + {params.knew_row_stride, _1{}, params.knew_head_stride, !is_varlen_k_new ? params.knew_batch_stride : 0}, // stride_K_new + static_cast(params.vnew_ptr), + {params.vnew_row_stride, _1{}, params.vnew_head_stride, !is_varlen_k_new ? 
params.vnew_batch_stride : 0}, // stride_V_new + static_cast(params.rotary_cos_ptr), + {params.seqlen_k, params.rotary_dim / 2}, // shape_rotary, the seqlen shape doesn't matter + {params.rotary_dim / 2, _1{}}, // stride_rotary_cos + static_cast(params.rotary_sin_ptr), + {params.rotary_dim / 2, _1{}}, // stride_rotary_sin + params.is_rotary_interleaved, + params.page_table, + // if page_size is not set, avoid dividing by zero + {params.kv_batch_idx ? params.b_k : params.b, !PagedKV ? 0 : params.seqlen_k / params.page_size}, // shape_page_table + {params.page_table_batch_stride, _1{}}, // stride_page_table + params.scale_softmax, + params.q_descale_ptr, params.k_descale_ptr, params.v_descale_ptr, + {params.q_descale_batch_stride, params.q_descale_head_stride}, + {params.k_descale_batch_stride, params.k_descale_head_stride}, + {params.v_descale_batch_stride, params.v_descale_head_stride}, + params.window_size_left, params.window_size_right, params.sink_token_length, + params.softcap, + params.num_splits, + params.kv_batch_idx, + params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew, + params.seqused_q, params.seqused_k, + params.leftpad_k, + }; + typename CollectiveEpilogue::Arguments epilogue_args { + static_cast(!Split ? params.o_ptr : params.oaccum_ptr), + {seqlen_q, params.d, params.h, batch_q, params.num_splits}, // shape_O + {!Split ? params.o_row_stride : params.oaccum_row_stride, + _1{}, + !Split ? params.o_head_stride : params.oaccum_head_stride, + !is_varlen_q ? (!Split ? params.o_batch_stride : params.oaccum_batch_stride) : 0, + !Split ? 0 : params.oaccum_split_stride}, // stride_O + static_cast(!Split ? params.softmax_lse_ptr : params.softmax_lseaccum_ptr), + {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, !Split ? 0 : params.h * seqlen_q * batch_q}, // stride_LSE + params.h_k, + params.cu_seqlens_q, params.seqused_q + }; + + int qhead_per_khead = !PackGQA ? 1 : cutlass::ceil_div(params.h, params.h_k); + int num_blocks_m = cutlass::ceil_div(params.seqlen_q * qhead_per_khead, get<0>(TileShape_MNK{})); + num_blocks_m = cutlass::round_up(num_blocks_m, size<0>(ClusterShape{})); + + typename flash::TileSchedulerArguments scheduler_args { + num_blocks_m, !PackGQA ? 
params.h : params.h_k, params.b, params.num_splits, + params.h / params.h_k, + params.seqlen_q, + params.seqlen_k, params.d, sizeof(Element), + params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q + }; + + int device; + CHECK_CUDA(cudaGetDevice(&device)); + typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({ + mainloop_args, epilogue_args, {device, params.num_sm}, scheduler_args + }); + + dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params); + dim3 block_dims = AttnKernel::get_block_shape(); + int smem_size = AttnKernel::SharedStorageSize; + + #if DEBUG + printf("Arch: %d kHeadDim: %d ClusterM: %d\n", Arch, kHeadDim, ClusterM); + printf("kStages: %d\n", kStages); + printf("grid_dims: %d %d %d\n", grid_dims.x, grid_dims.y, grid_dims.z); + printf("block_dims: %d %d %d\n", block_dims.x, block_dims.y, block_dims.z); + #endif + + if constexpr (size(ClusterShape{}) > 1) { + void const* kernel = (void const*) cutlass::device_kernel; + if (smem_size >= 48 * 1024) { + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{})); + cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream}; + cutlass::launch_kernel_on_cluster(launch_params, kernel, kernel_params); + } else { + auto kernel = cutlass::device_kernel; + if (smem_size >= 48 * 1024) { + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + kernel<<>>(kernel_params); + } + CHECK_CUDA_KERNEL_LAUNCH(); + +} + +template +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + static_assert(sizeof(T) == 2 || sizeof(T) == 1, "Only 16bit and 8bit are supported"); + static constexpr bool Is_FP8 = false; + using T_out = std::conditional_t, float>; + + // CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] { + // VCOLMAJOR_SWITCH(params.v_dim_stride != 1, V_colmajor_, [&] { + // static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1; + // VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] { + // // Only needed here to decide if we should use cluster + // static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKV, Has_softcap)) : 128; + + // static constexpr bool Enable_cluster = Arch >= 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKV && !Varlen; + // APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { + // // Only use Cluster if number of tiles along seqlen_q is even and not varlen + // CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { + // static constexpr int ClusterM = Enable_cluster && Use_cluster ? 
2 : 1; + + // printf("ClusterM: %d Is_causal: %d Is_local: %d Split: %d PagedKV: %d Varlen: %d AppendKV: %d PackGQA: %d V_colmajor: %d\n", ClusterM, Is_causal, Is_local, Split, PagedKV, Varlen, AppendKV, PackGQA, V_colmajor); + + + + static constexpr int ClusterM = 1; + static constexpr bool Is_causal = false; + static constexpr bool Is_local = false; + static constexpr bool Varlen = false; + static constexpr bool AppendKV = false; + static constexpr bool V_colmajor = false; + + auto tile_size = tile_size_fwd_sm90(kHeadDim, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKV, Has_softcap); + // printf("ClusterM: %d tile_size: %d %d %d %d\n", ClusterM, std::get<0>(tile_size), std::get<1>(tile_size), std::get<2>(tile_size), std::get<3>(tile_size)); + + run_flash_fwd(params, stream); + // }); + // }); + // }); + // }); + // }); + +} + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#define ARCH_SUPPORTS_FLASH +#define KERNEL_PARAM_MODIFIER __grid_constant__ +#else +#define KERNEL_PARAM_MODIFIER +#endif + +// Define a macro for unsupported architecture handling to centralize the error message +#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!"); + +#define DEFINE_FLASH_QPACK_KERNEL(kernelName, ...) \ +template \ +__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) + +DEFINE_FLASH_QPACK_KERNEL(flash_qpack_kernel) { + #if defined(ARCH_SUPPORTS_FLASH) + flash::compute_qpack(params); + #else + FLASH_UNSUPPORTED_ARCH + #endif +} + +template +void run_flash_qpack(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr size_t smem_size = Kernel_traits::kSmemSize; + + const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; + dim3 grid(num_n_block, params.b, params.h); + + auto kernel = &flash_qpack_kernel; + + if (smem_size >= 48 * 1024) { + CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + + kernel<<>>(params); + CHECK_CUDA_KERNEL_LAUNCH(); + +} + +template +void run_kvcache_qpack_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 128; + constexpr static int kBlockN = num_bits == 4 ? 128 : 256; + + run_flash_qpack>(params, stream); +} \ No newline at end of file diff --git a/hopper/src/genfile/flash_fwd_combine.cu b/hopper/src/genfile/flash_fwd_combine.cu new file mode 100644 index 0000000..d31ce1c --- /dev/null +++ b/hopper/src/genfile/flash_fwd_combine.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
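+// The combine kernel instantiated below merges the per-split partial results produced when
+// num_splits > 1: each split s writes an fp32 partial output O_s together with its
+// log-sum-exp LSE_s, and the combine pass reduces them as (sketch)
+//   LSE = log(sum_s exp(LSE_s))
+//   O   = sum_s exp(LSE_s - LSE) * O_s
+// before casting back to the output dtype.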
+ +#include "../flash_fwd_combine_launch_template.h" + +// template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); +// template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); +// template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); + +// template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); +// template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); + +// template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); +// template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); +// template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/hopper/src/genfile/flash_fwd_hdim128_fp16_split_sm90.cu b/hopper/src/genfile/flash_fwd_hdim128_fp16_split_sm90.cu new file mode 100644 index 0000000..46ced5f --- /dev/null +++ b/hopper/src/genfile/flash_fwd_hdim128_fp16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "../flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<90, cutlass::half_t, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/src/genfile/flash_qpack_hdim128_fp16_sm80_4bit.cu b/hopper/src/genfile/flash_qpack_hdim128_fp16_sm80_4bit.cu new file mode 100644 index 0000000..166362b --- /dev/null +++ b/hopper/src/genfile/flash_qpack_hdim128_fp16_sm80_4bit.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "../flash_fwd_launch_template.h" + + +// template<> +// void run_kvcache_qpack_(Flash_fwd_params ¶ms, cudaStream_t stream) { +// run_kvcache_qpack_hdim128(params, stream); +// } + + +template<> +void run_kvcache_qpack_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_kvcache_qpack_hdim128(params, stream); +} + diff --git a/hopper/src/include/block_info.h b/hopper/src/include/block_info.h new file mode 100644 index 0000000..2d6c6ea --- /dev/null +++ b/hopper/src/include/block_info.h @@ -0,0 +1,50 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BlockInfo { + + template + __device__ BlockInfo(const Params ¶ms, const int bidb) + : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]) + , sum_s_k(-1) + , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + , leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + , seqlen_k_cache(params.seqlen_k) + , actual_seqlen_k(params.seqused_k ? 
params.seqused_k[bidb] - leftpad_k : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) + { + } + + template + __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; + } + + template + __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_k == -1 ? bidb * batch_stride + leftpad_k * row_stride : uint32_t(sum_s_k + leftpad_k) * row_stride; + + // return bidb * batch_stride; + } + + const int sum_s_q; + const int sum_s_k; + const int actual_seqlen_q; + // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0. + const int leftpad_k; + const int seqlen_k_cache; + const int actual_seqlen_k; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/hopper/src/include/dequantize.h b/hopper/src/include/dequantize.h new file mode 100644 index 0000000..87ff21c --- /dev/null +++ b/hopper/src/include/dequantize.h @@ -0,0 +1,471 @@ +#pragma once + +#include +#include +#include +#include + + +namespace quant { + +using namespace cute; + +// Instances of `Vec` are used to organize groups of >>registers<<, as needed for instance as inputs to tensor core +// operations. Consequently, all corresponding index accesses must be compile-time constants, which is why we +// extensively use `#pragma unroll` throughout the kernel code to guarantee this. +template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { + return elems[i]; + } +}; + + +using I4 = Vec; + +// Matrix fragments for tensor core instructions; their precise layout is documented here: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type +using FragA = Vec; +using FragB = Vec; +using FragC = Vec; +using FragS = Vec; // quantization scales + + +// Lookup-table based 3-input logical operation; explicitly used for dequantization as the compiler does not seem to +// automatically recognize it in all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut) + ); + return res; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 values. +// We mostly follow the strategy in the link below, with some small changes: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +__device__ inline FragA lop3_dequant(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + + // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue + // immediately before required. + const uint32_t top_i4s = q >> 8; + + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo_1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); // 0,4 + int hi_1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); // 1,5 + int lo_2 = lop3<(0xf0 & 0xcc) | 0xaa>(top_i4s, LO, EX); // 2,6 + int hi_2 = lop3<(0xf0 & 0xcc) | 0xaa>(top_i4s, HI, EX); // 3,7 + + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point directly into `SUB` and `ADD`. 
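+  // Note on the constants below (the usual fp16 bit trick): after the LOP3s each 16-bit lane
+  // holds 0x64xx, i.e. the fp16 value 1024 + v for a low nibble or 1024 + 16*v for a high
+  // nibble. __hsub2 with 0x6400 (1024.0) recovers v for the low nibbles; __hfma2 with
+  // 0x2c00 (1/16) and 0xd400 (-64.0) recovers v for the high nibbles. The commented-out
+  // 0x6408/0xd480 values would instead yield the signed v - 8.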
+ const int SUB = 0x64006400; // 0x64086408 + const int MUL = 0x2c002c00; // {1/16, 1/16} + const int ADD = 0xd400d400; // 0xd480d480 + + FragA frag_a; + frag_a[0] = __hsub2( + *reinterpret_cast(&lo_1), + *reinterpret_cast(&SUB) + ); // 0,4 + frag_a[1] = __hfma2( + *reinterpret_cast(&hi_1), + *reinterpret_cast(&MUL), *reinterpret_cast(&ADD) + ); // 1,5 + frag_a[2] = __hsub2( + *reinterpret_cast(&lo_2), + *reinterpret_cast(&SUB) + ); // 2,6 + frag_a[3] = __hfma2( + *reinterpret_cast(&hi_2), + *reinterpret_cast(&MUL), *reinterpret_cast(&ADD) + ); // 3,7 + + return frag_a; +} + + +// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 values. +// We mostly follow the strategy in the link below, with some small changes: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +__device__ inline FragB lop3_dequant_2bit(int q) { + const int LO = 0x00030003; + const int HI = 0x00300030; + const int EX = 0x64006400; + + // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue + // immediately before required. + const uint32_t top_i4s = q >> 8; + + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo_1_a = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); // 0,8 + int lo_1_b = lop3<(0xf0 & 0xcc) | 0xaa>(q >> 2, LO, EX); // 1,9 + int hi_1_a = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); // 2,10 + int hi_1_b = lop3<(0xf0 & 0xcc) | 0xaa>(q >> 2, HI, EX); // 3,11 + int lo_2_a = lop3<(0xf0 & 0xcc) | 0xaa>(top_i4s, LO, EX); // 4,12 + int lo_2_b = lop3<(0xf0 & 0xcc) | 0xaa>(top_i4s >> 2, LO, EX); // 5,13 + int hi_2_a = lop3<(0xf0 & 0xcc) | 0xaa>(top_i4s, HI, EX); // 6,14 + int hi_2_b = lop3<(0xf0 & 0xcc) | 0xaa>(top_i4s >> 2, HI, EX); // 7,15 + + + // int hi_2 = lop3<(0xf0 & 0xcc) | 0xaa>(top_i4s, HI, EX); // 3,7 + + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point directly into `SUB` and `ADD`. 
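+  // Same 1024-bias trick as the 4-bit path above, applied to 2-bit fields: the masks
+  // 0x0003/0x0030 pick bits 0-1 and 4-5, the extra >>2 reaches bits 2-3 and 6-7, and the
+  // >>8 covers the upper half, so each 32-bit word expands into eight half2 fragments
+  // (sixteen 2-bit values).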
+ const int SUB = 0x64006400; // {1024, 1024} 0x64086408 + const int MUL = 0x2c002c00; // {1/16, 1/16} + const int ADD = 0xd400d400; // {-64, -64} 0xd480d480 + + FragB frag_b; + frag_b[0] = __hsub2( + *reinterpret_cast(&lo_1_a), + *reinterpret_cast(&SUB) + ); // 0,8 + frag_b[1] = __hsub2( + *reinterpret_cast(&lo_1_b), + *reinterpret_cast(&SUB) + ); // 1,9 + frag_b[2] = __hfma2( + *reinterpret_cast(&hi_1_a), + *reinterpret_cast(&MUL), *reinterpret_cast(&ADD) + ); // 2,10 + frag_b[3] = __hfma2( + *reinterpret_cast(&hi_1_b), + *reinterpret_cast(&MUL), *reinterpret_cast(&ADD) + ); // 3,11 + frag_b[4] = __hsub2( + *reinterpret_cast(&lo_2_a), + *reinterpret_cast(&SUB) + ); // 4,12 + frag_b[5] = __hsub2( + *reinterpret_cast(&lo_2_b), + *reinterpret_cast(&SUB) + ); // 5,13 + frag_b[6] = __hfma2( + *reinterpret_cast(&hi_2_a), + *reinterpret_cast(&MUL), *reinterpret_cast(&ADD) + ); // 6,14 + frag_b[7] = __hfma2( + *reinterpret_cast(&hi_2_b), + *reinterpret_cast(&MUL), *reinterpret_cast(&ADD) + ); // 7,15 + + return frag_b; +} + + + +////////////////////////////////////////////////////////////////////////////// +// Loading params +////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ +void +load_params_Kchannel( + Tensor0 & scales, + Tensor1 & zeros, + Tensor2 const& params, + int tidx, + int i, + const int num_params +) { + CUTE_UNROLL + for (int m = 0; m < size<1>(scales); ++m) { + CUTE_UNROLL + for (int j = 0; j < size<0>(scales); ++j) { + // seems no one can know why is this offset ... + scales(j, m, i) = params(m * num_params + j % num_params, 0 + 8 * i + 4 * (j / num_params) + tidx % 4); + zeros(j, m, i) = params(m * num_params + j % num_params, 64 + 8 * i + 4 * (j / num_params) + tidx % 4); + } + } +} + +template +__forceinline__ __device__ +void +load_params_Ktensor( + Tensor0_g & scales, + Tensor1_g & zeros, + Tensor2_g const& params, + int tidx, + const int num_params +) { + CUTE_UNROLL + for (int j = 0; j < size<0>(scales); ++j) { + scales(j) = params(128 * (j / num_params / 2) + 0 + 32 * ((j / num_params) % 2) + tidx / 4, j % num_params); + zeros(j) = params(128 * (j / num_params / 2) + 64 + 32 * ((j / num_params) % 2) + tidx / 4, j % num_params); + // scales(j) = params(0 + 32 * (j / num_params) + tidx / 4, j % num_params); + // zeros(j) = params(64 + 32 * (j / num_params) + tidx / 4, j % num_params); + } + + // CUTE_UNROLL + // for (int j = 0; j < size<0>(scales); ++j) { + // params(0 + 32 * (j / num_params) + threadIdx.x / 4, j % num_params) = scales(j); + // params(64 + 32 * (j / num_params) + threadIdx.x / 4, j % num_params) = zeros(j); + // } +} + +template +__forceinline__ __device__ +void +load_params_Vtensor( + Tensor0 & scales, + Tensor1 & zeros, + Tensor2 const& params, + int tidx, + int i, + const int num_params +) { + const int num_params_2 = num_bits == 2 ? num_params / 2 : num_params; + CUTE_UNROLL + for (int j = 0; j < size<0>(scales); ++j) { + // seems no one can know why is this offset ... 
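+        // Reading the index below as written: row = 128 * (i / 8) + 8 * (i % 8)
+        //   + 4 * (j / num_params_2) + tidx % 4, with the column given by j % num_params_2;
+        // zeros are read at the same coordinates shifted by +64 rows.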
+ scales(j, i) = params(128 * (i / 8) + 0 + 8 * (i % 8) + 4 * (j / num_params_2) + tidx % 4, j % num_params_2); + zeros(j, i) = params(128 * (i / 8) + 64 + 8 * (i % 8) + 4 * (j / num_params_2) + tidx % 4, j % num_params_2); + } +} + +////////////////////////////////////////////////////////////////////////////// +// Dequantization +////////////////////////////////////////////////////////////////////////////// + + +template +struct dequant_kc_vt; + +template +struct dequant_kc_vt<2, SourceEngine, SourceLayout, TargetEngine, TargetLayout, ScaleEngine, ScaleLayout, ZeroEngine, ZeroLayout> { + static constexpr int num_bits = 2; + CUTE_DEVICE static + void apply(cute::Tensor const& source, + cute::Tensor const& target, + cute::Tensor const& scales, + cute::Tensor const& zeros, + const int num_params) { + using TQ = cute::uint16_t; + using TQ2 = cute::uint32_t; + using T = typename TargetEngine::value_type; + using T2 = __half2; + const int num_params_ = num_params / 2; // TODO: only for g128 + const int pack_num = 4 / num_params_; // TODO: check 4 + + // vectorize the source and target + auto scales_vec = cute::recast(scales); + auto zeros_vec = cute::recast(zeros); + auto source_vec = cute::recast(source); + auto target_vec = cute::recast(target); + + const int channel_stride = size<0>(source_vec); + + CUTE_UNROLL + for (int i = 0; i < cute::size<0>(source_vec); ++i) { + + CUTE_UNROLL + for (int p = 0; p < cute::size<1>(source_vec); ++p) { + auto src_crd = cute::make_coord(i, p); + auto src_raw = source_vec(src_crd); + auto src_val = lop3_dequant_2bit(src_raw); + + CUTE_UNROLL + for (int j = 0; j < size<1>(target_vec); ++j) { + target_vec(i, j) = __hfma2(src_val[j], scales_vec(i + j / pack_num * channel_stride), zeros_vec(i + j / pack_num * channel_stride)); + } + + // target_vec(i,0) = __hfma2(src_val[0], scales_vec(i), zeros_vec(i)); + // target_vec(i,1) = __hfma2(src_val[1], scales_vec(i + 1 / pack_num * channel_stride), zeros_vec(i + 1 / pack_num * channel_stride)); + // target_vec(i,2) = __hfma2(src_val[2], scales_vec(i + 2 / pack_num * channel_stride), zeros_vec(i + 2 / pack_num * channel_stride)); + // target_vec(i,3) = __hfma2(src_val[3], scales_vec(i + 3 / pack_num * channel_stride), zeros_vec(i + 3 / pack_num * channel_stride)); + // target_vec(i,4) = __hfma2(src_val[4], scales_vec(i + 4 / pack_num * channel_stride), zeros_vec(i + 4 / pack_num * channel_stride)); + // target_vec(i,5) = __hfma2(src_val[5], scales_vec(i + 5 / pack_num * channel_stride), zeros_vec(i + 5 / pack_num * channel_stride)); + // target_vec(i,6) = __hfma2(src_val[6], scales_vec(i + 6 / pack_num * channel_stride), zeros_vec(i + 6 / pack_num * channel_stride)); + // target_vec(i,7) = __hfma2(src_val[7], scales_vec(i + 7 / pack_num * channel_stride), zeros_vec(i + 7 / pack_num * channel_stride)); + + // target_vec(i,0) = __hfma2(src_val[0], scales_vec(0), zeros_vec(0)); + // target_vec(i,1) = __hfma2(src_val[1], scales_vec(0), zeros_vec(0)); + // target_vec(i,2) = __hfma2(src_val[2], scales_vec(0), zeros_vec(0)); + // target_vec(i,3) = __hfma2(src_val[3], scales_vec(0), zeros_vec(0)); + // target_vec(i,4) = __hfma2(src_val[4], scales_vec(0), zeros_vec(0)); + // target_vec(i,5) = __hfma2(src_val[5], scales_vec(0), zeros_vec(0)); + // target_vec(i,6) = __hfma2(src_val[6], scales_vec(0), zeros_vec(0)); + // target_vec(i,7) = __hfma2(src_val[7], scales_vec(0), zeros_vec(0)); + + // target_vec(i,0) = src_val[0]; + // target_vec(i,1) = src_val[1]; + // target_vec(i,2) = src_val[2]; + // target_vec(i,3) = src_val[3]; + // 
target_vec(i,4) = src_val[4]; + // target_vec(i,5) = src_val[5]; + // target_vec(i,6) = src_val[6]; + // target_vec(i,7) = src_val[7]; + } + } + } +}; + +template +struct dequant_kc_vt<4, SourceEngine, SourceLayout, TargetEngine, TargetLayout, ScaleEngine, ScaleLayout, ZeroEngine, ZeroLayout> { + static constexpr int num_bits = 4; + CUTE_DEVICE static + void apply(cute::Tensor const& source, + cute::Tensor const& target, + cute::Tensor const& scales, + cute::Tensor const& zeros, + const int num_params) { + using TQ = cute::uint16_t; + using TQ2 = cute::uint32_t; + using T = typename TargetEngine::value_type; + using T2 = __half2; + const int pack_num = 4 / num_params; + + // vectorize the source and target + auto scales_vec = cute::recast(scales); + auto zeros_vec = cute::recast(zeros); + auto source_vec = cute::recast(source); + auto target_vec = cute::recast(target); + + const int channel_stride = cute::size<0>(source_vec); + const int scales_stride = cute::size<0>(scales_vec); + + CUTE_UNROLL + for (int i = 0; i < cute::size<0>(source_vec); ++i) // 2 + { + CUTE_UNROLL + for (int p = 0; p < cute::size<1>(source_vec); ++p) // 1 + { + auto src_crd = cute::make_coord(i, p); + auto src_raw = source_vec(src_crd); + auto src_val = lop3_dequant(src_raw); + + auto col_offset = p * num_bits; + + auto tgt0_crd = cute::make_coord(i, col_offset + 0); + auto tgt1_crd = cute::make_coord(i, col_offset + 1); + auto tgt2_crd = cute::make_coord(i, col_offset + 2); + auto tgt3_crd = cute::make_coord(i, col_offset + 3); + + // TODO: hard code for now 2 + int params_crd = i; + + target_vec(tgt0_crd) = __hfma2(src_val[0], scales_vec(params_crd + p * scales_stride), zeros_vec(params_crd + p * scales_stride)); + target_vec(tgt1_crd) = __hfma2(src_val[1], scales_vec(params_crd + 1 / pack_num * channel_stride + p * scales_stride), zeros_vec(params_crd + 1 / pack_num * channel_stride + p * scales_stride)); + target_vec(tgt2_crd) = __hfma2(src_val[2], scales_vec(params_crd + 2 / pack_num * channel_stride + p * scales_stride), zeros_vec(params_crd + 2 / pack_num * channel_stride + p * scales_stride)); + target_vec(tgt3_crd) = __hfma2(src_val[3], scales_vec(params_crd + 3 / pack_num * channel_stride + p * scales_stride), zeros_vec(params_crd + 3 / pack_num * channel_stride + p * scales_stride)); + + // target_vec(tgt0_crd) = src_val[0]; + // target_vec(tgt1_crd) = src_val[1]; + // target_vec(tgt2_crd) = src_val[2]; + // target_vec(tgt3_crd) = src_val[3]; + } + } + } +}; + +template +CUTE_DEVICE +void +dequant_Kchannel_Vtensor( + cute::Tensor const& source, + cute::Tensor const& target, + cute::Tensor const& scales_vec, + cute::Tensor const& zeros_vec, + const int num_params=1 +) { + dequant_kc_vt::apply(source, target, scales_vec, zeros_vec, num_params); +} + +template +CUTE_DEVICE +void +dequantize_Ktensor( + cute::Tensor const& source_, + cute::Tensor & target_, + TensorParamsG1 & scales_k_g_vec, + TensorParamsG2 & zeros_k_g_vec, + int num_bits, + int group_size, + int ii +) { + using TQ = cute::uint16_t; + using TQ2 = cute::uint32_t; + using T = typename TargetEngine::value_type; + using T2 = __half2; + + // vectorize the source and target + auto source = source_(_,_,_,_0{}); + auto target = target_(_,_,_,_0{}); + + static constexpr int kNumBits = 4; + const int num_params = 128 / group_size; + const int ki = size<2>(target) / num_params; + + auto scales_k_g = cute::recast(scales_k_g_vec); + auto zeros_k_g = cute::recast(zeros_k_g_vec); + auto source_vec = cute::recast(source); + auto target_vec = 
cute::recast(target); + + const int tile_j = size<2>(target) != size<2>(source) ? 2 : 1; + + CUTE_UNROLL + for (int i = 0; i < cute::size<0>(source_vec); ++i) + { + auto src_crd = cute::make_coord(0, 0, 0); + for (int p = 0; p < tile_j; ++p) { + src_crd = tile_j == 1 ? cute::make_coord(i, 0, ii) : cute::make_coord(i, 0, 8 * (ii / 4) + ii % 4 + p * 4); + auto src_raw = source_vec(src_crd); + auto src_val = lop3_dequant(src_raw); + + auto col_offset = p * kNumBits; + + auto tgt0_crd = cute::make_coord(i, col_offset + 0, ii); + auto tgt1_crd = cute::make_coord(i, col_offset + 1, ii); + auto tgt2_crd = cute::make_coord(i, col_offset + 2, ii); + auto tgt3_crd = cute::make_coord(i, col_offset + 3, ii); + + // Create half2 values for scales and zeros + half2 scales_k_g_0 = __half2half2(__half(scales_k_g(ii / ki + col_offset * num_params + 0 * num_params))); + half2 scales_k_g_1 = __half2half2(__half(scales_k_g(ii / ki + col_offset * num_params + 1 * num_params))); + half2 scales_k_g_2 = __half2half2(__half(scales_k_g(ii / ki + col_offset * num_params + 2 * num_params))); + half2 scales_k_g_3 = __half2half2(__half(scales_k_g(ii / ki + col_offset * num_params + 3 * num_params))); + + half2 zeros_k_g_0 = __half2half2(__half(zeros_k_g(ii / ki + col_offset * num_params + 0 * num_params))); + half2 zeros_k_g_1 = __half2half2(__half(zeros_k_g(ii / ki + col_offset * num_params + 1 * num_params))); + half2 zeros_k_g_2 = __half2half2(__half(zeros_k_g(ii / ki + col_offset * num_params + 2 * num_params))); + half2 zeros_k_g_3 = __half2half2(__half(zeros_k_g(ii / ki + col_offset * num_params + 3 * num_params))); + + target_vec(tgt0_crd) = __hfma2(src_val[0], scales_k_g_0, zeros_k_g_0); + target_vec(tgt1_crd) = __hfma2(src_val[1], scales_k_g_1, zeros_k_g_1); + target_vec(tgt2_crd) = __hfma2(src_val[2], scales_k_g_2, zeros_k_g_2); + target_vec(tgt3_crd) = __hfma2(src_val[3], scales_k_g_3, zeros_k_g_3); + + // target_vec(tgt0_crd) = src_val[0]; + // target_vec(tgt1_crd) = src_val[1]; + // target_vec(tgt2_crd) = src_val[2]; + // target_vec(tgt3_crd) = src_val[3]; + } + + + } + +} + +} // namespace quant \ No newline at end of file diff --git a/hopper/src/include/epilogue_fwd.hpp b/hopper/src/include/epilogue_fwd.hpp new file mode 100644 index 0000000..0f91606 --- /dev/null +++ b/hopper/src/include/epilogue_fwd.hpp @@ -0,0 +1,419 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include +#include // For FastDivMod +#include "cute/tensor.hpp" + +#include "cutlass/gemm/collective/builders/sm90_common.inl" +#include "cutlass/epilogue/collective/builders/sm90_common.inl" + +#include "seqlen.h" +#include "named_barrier.hpp" +#include "pack_gqa.h" +#include "utils.h" + +namespace flash { + +using namespace cute; + +template +struct CollectiveEpilogueFwd { + + using TileShape_MNK = TileShape_MNK_; + using ClusterShape = ClusterShape_; + using Element = Element_; + using ArchTag = ArchTag_; + static constexpr int NumEpilogueThreads = NumEpilogueThreads_; + static constexpr bool Varlen = Varlen_; + static constexpr bool PackGQA = PackGQA_; + static constexpr bool Use_smem = sizeof(Element) <= 2; + static constexpr bool Use_TMA_O = ArchTag::kMinComputeCapability >= 90 && !Varlen && Use_smem && !PackGQA; + + static_assert(ArchTag::kMinComputeCapability >= 80); + static_assert(ArchTag::kMinComputeCapability >= 90 || CUTE_STATIC_V(size(ClusterShape{})) == 1); + + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + + using GmemTiledCopyOTMA = cute::SM90_TMA_STORE; + + // These are for storing the output tensor without TMA (e.g., for setting output to zero) + static constexpr int kGmemElemsPerStore = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerStore == 0, "Headdim must be a multiple of kGmemElemsPerStore"); + // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). We want each thread to have 4 elements + // in the M direction and 2 elements in the K direction. In the case of PackGQA, this reduces the number of times + // we need to call divmod. + static constexpr int kBytePerRow = kHeadDim * sizeof(Element); + static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); + // static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + // static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerStore, NumEpilogueThreads); + static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerStore; + // If PackGQA, we split the work of compute O_ptr among threads in the same row, so we need this to within a warp + static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0); + static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0, "kBlockM must be a multiple of NumEpilogueThreads / kGmemThreadsPerRow"); + using GmemTiledCopyO = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 8 or 16 vals per store + + using SmemLayoutAtomOTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutOTMA = decltype(tile_to_shape(SmemLayoutAtomOTMA{}, select<0, 2>(TileShape_MNK{}))); + static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1)); + static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 
3 : 4); + using SmemLayoutAtomO = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutOSTS = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{}))); + using SmemLayoutO = std::conditional_t= 90, SmemLayoutOTMA, SmemLayoutOSTS>; + + using ShapeO = cute::Shape; // (seqlen_q, d, head, batch, num_splits) + using StrideO = cute::Stride; + using StrideLSE = cute::Stride<_1, int64_t, int64_t, int64_t>; // (seqlen_q, head, batch, num_splits) + // ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits) + using ShapeOPacked = std::conditional_t, int32_t, int32_t, int32_t, int32_t>>; + using StrideOPacked = std::conditional_t, _1, int64_t, int64_t, int64_t>>; + // ((qhead_per_khead, seqlen_q), nheads_kv, batch, num_splits) + using ShapeLSEPacked = std::conditional_t, cute::Shape, int32_t, int32_t, int32_t>>; + using StrideLSEPacked = std::conditional_t, int64_t, int64_t, int64_t>>; + + using CopyOpR2S = std::conditional_t< + ArchTag::kMinComputeCapability >= 90, + // cute::SM90_U32x4_STSM_N if Element size is 2 bytes (fp16, bf16) + decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator()), + AutoVectorizingCopyWithAssumedAlignment<128> + >; + using SmemCopyAtomO = Copy_Atom; + + // static constexpr size_t SmemAlignmentO = cutlass::detail::alignment_for_swizzle(SmemLayoutO{}); + // static_assert(SmemAlignmentO >= 128, "Require at least 128B alignment"); + // struct TensorStorage : cute::aligned_struct { + // cute::array_aligned : 0, SmemAlignmentO> smem_o; + // }; + struct TensorStorage : cute::aligned_struct<128> { + cute::array_aligned : 0> smem_o; + }; + + using TMA_O = std::conditional_t< + Use_TMA_O, + decltype(make_tma_copy( + GmemTiledCopyOTMA{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeO{}, StrideO{}), + SmemLayoutOTMA{}, + select<0, 2>(TileShape_MNK{}), + _1{})), // no mcast for O + std::nullptr_t + >; + + // Host side kernel arguments + struct Arguments { + Element* ptr_O; + ShapeO const shape_O; + StrideO const stride_O; + float* ptr_LSE; + StrideLSE const stride_LSE; + int32_t const nheads_kv; + int const* cu_seqlens = nullptr; + int const* seqused = nullptr; + }; + + // Device side kernel params + struct Params { + Element* ptr_O; + ShapeO const shape_O; + StrideO const stride_O; + ShapeOPacked const shape_O_packed; + StrideOPacked const stride_O_packed; + float* ptr_LSE; + StrideLSE const stride_LSE; + ShapeLSEPacked const shape_LSE_packed; + StrideLSEPacked const stride_LSE_packed; + cutlass::FastDivmod qhead_per_khead_divmod; + TMA_O tma_store_O; + int const* cu_seqlens = nullptr; + int const* seqused = nullptr; + }; + + static Params + to_underlying_arguments(Arguments const& args) { + Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O); + TMA_O tma_store_O = [&]{ + if constexpr (Use_TMA_O) { + return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 2>(TileShape_MNK{}), _1{}); // no mcast + } else { + return nullptr; + } + }(); + // If PackGQA, reshape O to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size, num_splits) + int const qhead_per_khead = !PackGQA ? 
1 : cute::ceil_div(get<2>(args.shape_O), args.nheads_kv); + auto const shape_O_packed = cute::conditional_return( + args.shape_O, + make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), get<1>(args.shape_O), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O)) + ); + auto const stride_O_packed = cute::conditional_return( + args.stride_O, + make_stride(make_stride(get<2>(args.stride_O), get<0>(args.stride_O)), get<1>(args.stride_O), get<2>(args.stride_O) * qhead_per_khead, get<3>(args.stride_O), get<4>(args.stride_O)) + ); + // If PackGQA, Reshape LSE to be ((qhead_per_khead, seqlen_q), nhead_k, batch_size, num_splits) + auto const shape_LSE_packed = cute::conditional_return( + select<0, 2, 3, 4>(args.shape_O), + make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O)) + ); + auto const stride_LSE_packed = cute::conditional_return( + args.stride_LSE, + make_stride(make_stride(get<1>(args.stride_LSE), get<0>(args.stride_LSE)), get<1>(args.stride_LSE) * qhead_per_khead, get<2>(args.stride_LSE), get<3>(args.stride_LSE)) + ); + return {args.ptr_O, args.shape_O, args.stride_O, shape_O_packed, stride_O_packed, + args.ptr_LSE, args.stride_LSE, shape_LSE_packed, stride_LSE_packed, + cutlass::FastDivmod(qhead_per_khead), + tma_store_O, args.cu_seqlens, args.seqused}; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + if constexpr (Use_TMA_O) { + cute::prefetch_tma_descriptor(params.tma_store_O.get_tma_descriptor()); + } + } + + template + CUTLASS_DEVICE void + store(Params const& params, + FrgTensorO const& tOrO, + FrgTensorLSE const& lse, + SharedStorage& shared_storage, + TiledMma tiled_mma, + int thread_idx, + cute::tuple const& block_coord + ) { + + auto [m_block, bidh, bidb, split_idx] = block_coord; + Tensor sO = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_o.data()), SmemLayoutO{}); + // Tensor sO_pi = cute::as_position_independent_swizzle_tensor(sO); + + Tensor tOrO_out = make_tensor_like(tOrO); + flash::convert_type_out(tOrO, tOrO_out); + if constexpr (FP8PermuteCol && (sizeof(Element) == 2 || sizeof(Element) == 4)) { flash::permute_output_fp8_Vcolmajor(tOrO_out); } + + // Make sure all WGs have finished reading V + // Technically we don't need this if we're not using smem, but the mainloop makes the assumption that + // all epilogue threads sync at least once during the epilogue (so that we can start loading Q with + // cp.async if we need). 
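+        // smem_o aliases smem_v (they share a union in the kernel's SharedStorage), so the
+        // barrier below also guarantees that no warpgroup is still reading V out of that
+        // space when the rmem -> smem copy of O starts.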
+ flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + + // Step 1: Write O from rmem -> smem + if constexpr (Use_smem) { + auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx); + Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + // Tensor taccOsO = smem_thr_copy_O.partition_D(sO_pi); // ((Atom,AtomNum),PIPE_M,PIPE_N) + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); + if constexpr (Use_TMA_O) { + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + } else { + flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + } + } else { + if constexpr (ArchTag::kMinComputeCapability >= 90) { + #pragma unroll + for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) { + shared_storage.pipelines.barrier_O.arrive(cta_id); + } + } + } + + flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused}; + bool is_varlen = Varlen && params.cu_seqlens; + int offset_o = seqlen_info.offset; + int seqlen_o = seqlen_info.seqlen; + + // Step 2: Write LSE from rmem -> gmem + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + // (MMA,MMA_M,MMA_K) + Tensor taccOcO = thread_mma.partition_C(cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}))); + static_assert(decltype(size<0, 0>(taccOcO))::value == 2); + static_assert(decltype(size<0, 1>(taccOcO))::value == 2); + Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); + Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + + using PackGQAt = flash::PackGQAManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumEpilogueThreads, Element>; + + Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset_o * get<0>(params.stride_LSE)), params.shape_LSE_packed, params.stride_LSE_packed)(_, bidh, !is_varlen ? 
bidb : 0, split_idx); + // if (thread_idx == 0) { printf("Before LSE write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o); print(mLSE); printf("\n"); } + if constexpr (!PackGQA) { + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + int const row = m_block * kBlockM + get<0>(taccOcO_row(mi)); + if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o) { mLSE(row) = lse(mi); } + } + } else { + PackGQAt::store_LSE(mLSE, lse, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); + } + + // Step 3: Write O from smem -> gmem + if constexpr (Use_TMA_O) { + Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O)(_, _, bidh, bidb, split_idx); + Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + auto block_tma_O = params.tma_store_O.get_slice(_0{}); + Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K) + Tensor tOsO = block_tma_O.partition_S(sO); // (TMA, TMA_M, TMA_K) + int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0); + if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) { + cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + if (cute::elect_one_sync()) { + cute::copy(params.tma_store_O, tOsO, tOgO); + tma_store_arrive(); + tma_store_wait<0>(); + #pragma unroll + for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) { + shared_storage.pipelines.barrier_O.arrive(cta_id); + } + } + } + } else { // Don't use TMA in Varlen case since we don't want to overwrite the output of another sequence + Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? 
bidb : 0, split_idx); + Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + // if (thread_idx == 0) { printf("Before O write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d, mO_addr = %p, addr diff = %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o, mO.data(), reinterpret_cast(&mO(0)) - reinterpret_cast(params.ptr_O)); } + if constexpr (Use_smem) { + GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); + Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) + // Tensor tOsO = gmem_thr_copy_O.partition_S(sO_pi); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOrO = make_fragment_like(tOsO); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); + if constexpr (ArchTag::kMinComputeCapability >= 90) { + cutlass::arch::fence_view_async_shared(); // ensure smem reads are done before next TMA to smem_v + #pragma unroll + for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) { + shared_storage.pipelines.barrier_O.arrive(cta_id); + } + } + if constexpr (!PackGQA) { + // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}))); + Tensor tOpO = make_tensor(make_shape(size<2>(tOsO))); + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); } + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM + ); + } else { + // If PackGQA, we split the work of compute O_ptr among threads in the same row + PackGQAt::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); + } + } else { + // We already arrived on barrier_O earlier + if constexpr (!PackGQA) { + static constexpr int kGmemElemsPerStoreDirect = 2; + cute::Copy_Atom, Element> gmem_copy_direct; + // Reshape acc from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor tOrO_rowcol = make_tensor(tOrO_out.data(), flash::convert_layout_acc_rowcol(tOrO.layout())); + Tensor tOrO_copy = cute::tiled_divide(tOrO_rowcol, Shape<_1, Int>{}); + Tensor tOgO = thread_mma.partition_C(gO); + Tensor tOgO_rowcol = make_tensor(tOgO.data(), flash::convert_layout_acc_rowcol(tOgO.layout())); + Tensor tOgO_copy = cute::tiled_divide(tOgO_rowcol, Shape<_1, Int>{}); + Tensor taccOcO_col = taccOcO_rowcol(_0{}, _); + #pragma unroll + for (int m = 0; m < size(taccOcO_row); ++m) { + if (get<0>(taccOcO_row(m)) < seqlen_o - m_block * kBlockM) { + #pragma unroll + for (int k = 0; k < size(taccOcO_col) / kGmemElemsPerStoreDirect; ++k) { + if (get<1>(taccOcO_col(k * kGmemElemsPerStoreDirect)) < get<1>(params.shape_O)) { + cute::copy(gmem_copy_direct, tOrO_copy(_, m, k), tOgO_copy(_, m, k)); + } + } + } + } + } else { + PackGQAt::store_O_direct(mO, tOrO_out, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); + } + } + } + } + + CUTLASS_DEVICE void + store_tail() { + // Don't need to do tma_store_wait<0>() here since we already did in @store + } + + // Write 0 to output and -inf to LSE + template + CUTLASS_DEVICE void + store_zero( + Params const& params, + int thread_idx, + cute::tuple const& block_coord + ) { + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + auto [m_block, bidh, bidb, split_idx] = block_coord; + flash::SeqlenInfo seqlen_info{bidb, 
size<0>(params.shape_O), params.cu_seqlens, params.seqused}; + bool const is_varlen = Varlen && params.cu_seqlens; + int offset_o = seqlen_info.offset; + int seqlen_o = seqlen_info.seqlen; + int qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor; + Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx); + Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset_o * get<0>(params.stride_LSE)), params.shape_LSE_packed, params.stride_LSE_packed)(_, bidh, !is_varlen ? bidb : 0, split_idx); + Tensor gLSE = local_tile(mLSE, Shape>{}, make_coord(m_block)); + + static_assert(kBlockM <= NumEpilogueThreads); + if (thread_idx < kBlockM) { + const int row = m_block * kBlockM + thread_idx; + if constexpr (!PackGQA) { + if (row < seqlen_o) { mLSE(row) = -INFINITY; } + } else { + if (row < seqlen_o * qhead_per_khead) { + int m_idx, h_idx; + m_idx = params.qhead_per_khead_divmod.divmod(h_idx, row); + // mLSE shape shape ((qhead_per_khead, seqlen_q)) and it's unhappy with just 1 "make_coord" + mLSE(make_coord(make_coord(h_idx, m_idx))) = -INFINITY; + } + } + } + + if constexpr (!Clear_O) { return; } + + GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); + Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}))); + if constexpr (!PackGQA) { + Tensor tOpO = make_tensor(make_shape(size<2>(tOcO))); + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); } + Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_fragment_like(tOgO); + cute::clear(tOrO); + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM + ); + } else { + // If PackGQA, we split the work of compute O_ptr among threads in the same row + using PackGQAt = flash::PackGQAManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumEpilogueThreads, Element>; + Tensor tOrO = make_tensor(make_shape(Shape<_1, Int>{}, size<1>(tOcO), size<2>(tOcO))); + cute::clear(tOrO); + PackGQAt::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); + } + + } + +}; + +} // namespace flash diff --git a/hopper/src/include/flash.h b/hopper/src/include/flash.h new file mode 100644 index 0000000..7d0058d --- /dev/null +++ b/hopper/src/include/flash.h @@ -0,0 +1,239 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Qkv_params { + using index_t = int64_t; + // The QKV matrices. + void *__restrict__ q_ptr; + void *__restrict__ k_ptr; + void *__restrict__ K_pack_ptr; + void *__restrict__ k_params_ptr; + void *__restrict__ v_ptr; + void *__restrict__ v_pack_ptr; + void *__restrict__ v_params_ptr; + + // The stride between rows of the Q, K and V matrices. 
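+    // For a contiguous [batch, seqlen, heads, head_dim] tensor these would typically be
+    //   batch_stride = seqlen * heads * head_dim, row_stride = heads * head_dim,
+    //   head_stride = head_dim
+    // (illustrative only; the actual values come from the caller's tensor layout). The
+    // *_pack_* and *_params_* strides are the corresponding strides of the packed
+    // (quantized) K/V tensors and of their per-group scale/zero parameter tensors.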
+ index_t q_batch_stride; + + index_t k_batch_stride; + index_t K_pack_batch_stride; + index_t k_params_batch_stride; + + index_t v_batch_stride; + index_t v_pack_batch_stride; + index_t v_params_batch_stride; + + index_t q_row_stride; + + index_t k_row_stride; + index_t K_pack_row_stride; + index_t k_params_row_stride; + + index_t v_row_stride; + index_t v_pack_row_stride; + index_t v_params_row_stride; + + index_t q_head_stride; + index_t k_head_stride; + index_t K_pack_head_stride; + index_t k_params_head_stride; + + index_t v_head_stride; + index_t v_pack_head_stride; + index_t v_params_head_stride; + + index_t k_params_dim_stride; + + index_t v_dim_stride; + index_t v_params_dim_stride; + + // The number of heads. + int h, h_k; + int h_h_k_ratio; // precompute h / h_k, + + std::string quant_mode; + int group_size; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_fwd_params : public Qkv_params { + using index_t = int64_t; + + // The O matrix (output). + void * __restrict__ o_ptr; + void * __restrict__ oaccum_ptr; + + // The stride between rows of O. + index_t o_batch_stride; + index_t o_row_stride; + index_t o_head_stride; + + // The pointer to the softmax sum. + void * __restrict__ softmax_lse_ptr; + void * __restrict__ softmax_lseaccum_ptr; + + // For FP8 scaling + float * __restrict__ q_descale_ptr; + float * __restrict__ k_descale_ptr; + float * __restrict__ v_descale_ptr; + index_t q_descale_batch_stride; + index_t q_descale_head_stride; + index_t k_descale_batch_stride; + index_t k_descale_head_stride; + index_t v_descale_batch_stride; + index_t v_descale_head_stride; + + // The dimensions. + int b, seqlen_q, seqlen_k, seqlen_k_pack, seqlen_knew, seqlen_k_params, d, d_kpack, d_vpack, d_vparams,seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim; + int total_q, total_k, total_knew; + int b_k; // When having KV cache and with cache_batch_idx, K & V might have larger batch size than Q + + // The scaling factors for the kernel. + float scale_softmax; + float softcap; + + // array of length b+1 holding starting offset of each sequence. + int * __restrict__ cu_seqlens_q; + int * __restrict__ cu_seqlens_k; + int * __restrict__ cu_seqlens_knew; + int * __restrict__ leftpad_k; + + // If provided, the actual length of each q/k sequence. + int *__restrict__ seqused_q; + int *__restrict__ seqused_k; + + // The stride between rows of Oaccum. + index_t oaccum_split_stride; + index_t oaccum_batch_stride; + index_t oaccum_row_stride; + index_t oaccum_head_stride; + + // The stride between rows of LSEaccum. + index_t lseaccum_split_stride; + index_t lseaccum_batch_stride; + index_t lseaccum_head_stride; + + // The K_new and V_new matrices. + void * __restrict__ knew_ptr; + void * __restrict__ vnew_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t knew_batch_stride; + index_t vnew_batch_stride; + index_t knew_row_stride; + index_t vnew_row_stride; + index_t knew_head_stride; + index_t vnew_head_stride; + + // The cos and sin matrices for rotary embedding. + void * __restrict__ rotary_cos_ptr; + void * __restrict__ rotary_sin_ptr; + + // The indices to index into the KV cache. + int * __restrict__ kv_batch_idx; + + // Paged KV cache + int * __restrict__ page_table; + index_t page_table_batch_stride; + int page_size; + int num_pages; + + // The dropout probability (probability of keeping an activation). 
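+    // (p_dropout_in_uint8_t below is, in the upstream FlashAttention convention, the keep
+    // probability rescaled to [0, 255] so the dropout mask can be generated with integer
+    // comparisons; stated here as background, nothing in this header enforces it.)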
+ float p_dropout; + // uint32_t p_dropout_in_uint; + // uint16_t p_dropout_in_uint16_t; + uint8_t p_dropout_in_uint8_t; + + // Scale factor of 1 / (1 - p_dropout). + float rp_dropout; + + // Local window size + int window_size_left, window_size_right; + int sink_token_length; + + // Pointer to the RNG seed (idx 0) and offset (idx 1). + uint64_t * rng_state; + + bool is_bf16; + bool is_fp32; + bool is_e4m3; + bool is_causal; + bool is_local; + + bool is_rotary_interleaved; + + int num_splits; // For split-KV version + bool pack_gqa; + + int * __restrict__ tile_count_semaphore; + + int arch; + int num_sm; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_bwd_params : public Flash_fwd_params { + using index_t = int64_t; + + // The dO and dQKV matrices. + void *__restrict__ do_ptr; + void *__restrict__ dq_ptr; + void *__restrict__ dk_ptr; + void *__restrict__ dv_ptr; + + // To accumulate dQ + void *__restrict__ dq_accum_ptr; + void *__restrict__ dk_accum_ptr; + void *__restrict__ dv_accum_ptr; + + // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q + // dimension void *__restrict__ dk_accum_ptr; void *__restrict__ + // dv_accum_ptr; + + // The stride between rows of the dO, dQ, dK and dV matrices. + index_t do_batch_stride; + index_t do_row_stride; + index_t do_head_stride; + index_t dq_batch_stride; + index_t dk_batch_stride; + index_t dv_batch_stride; + index_t dq_row_stride; + index_t dk_row_stride; + index_t dv_row_stride; + index_t dq_head_stride; + index_t dk_head_stride; + index_t dv_head_stride; + + // The pointer to the softmax d sum. + void *__restrict__ dsoftmax_sum; + void *__restrict__ softmax_lse_log2_ptr; + + int *__restrict__ dq_semaphore; + int *__restrict__ dk_semaphore; + int *__restrict__ dv_semaphore; + + bool deterministic; + index_t dq_accum_split_stride; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); + +template +void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); + +template +void run_kvcache_qpack_(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/hopper/src/include/flash_fwd_combine_kernel.h b/hopper/src/include/flash_fwd_combine_kernel.h new file mode 100644 index 0000000..aaec31e --- /dev/null +++ b/hopper/src/include/flash_fwd_combine_kernel.h @@ -0,0 +1,453 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" + +#include +#include +#include +#include +#include + +#include "seqlen.h" +#include "utils.h" + +namespace flash { + +using namespace cute; + +template +class FlashAttnFwdCombine { + +public: + + // Type Aliases + using TileShape_MK = TileShape_MK_; + using ArchTag = ArchTag_; + static constexpr int kMaxSplits = 1 << kLogMaxSplits_; + static constexpr int AlignmentLSE = std::min(AlignmentLSE_, int(128 / 8 / sizeof(float))); + static_assert(AlignmentLSE >= 1); + static constexpr int kStages = 4; + + static_assert(ArchTag::kMinComputeCapability >= 75); + static constexpr bool Has_cp_async = ArchTag::kMinComputeCapability >= 80; + + static constexpr uint32_t MaxThreadsPerBlock = kNThreads; + static constexpr uint32_t MinBlocksPerMultiprocessor = 2; + + static constexpr int kBlockM = get<0>(TileShape_MK{}); + static constexpr int kHeadDim = get<1>(TileShape_MK{}); + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(ElementPartial); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; + static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow"); + using GmemCopyAtom = std::conditional_t< + Has_cp_async, + cute::Copy_Atom, ElementPartial>, + cute::Copy_Atom, ElementPartial> + >; + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0); + using GmemTiledCopyAccum = decltype( + make_tiled_copy(GmemCopyAtom{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 4 vals per load + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 4 vals per load + + using AlignmentTypeLSE = cute::uint_byte_t(sizeof(float)) * AlignmentLSE>; + static constexpr int kGmemElemsPerLoadLSE = sizeof(AlignmentTypeLSE) / sizeof(float); + static_assert(kBlockM % kGmemElemsPerLoadLSE == 0, "kBlockM must be a multiple of kGmemElemsPerLoadLSE"); + static_assert(kBlockM % 8 == 0, "kBlockM must be a multiple of 8"); + static constexpr int kBlockMSmem = kBlockM % 128 == 0 ? 128 : (kBlockM % 64 == 0 ? 64 : (kBlockM % 32 == 0 ? 32 : (kBlockM % 16 == 0 ? 
16 : 8))); + static constexpr int kGmemThreadsPerRowLSE = kBlockMSmem / kGmemElemsPerLoadLSE; + static_assert(MaxThreadsPerBlock % kGmemThreadsPerRowLSE == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRowLSE"); + using GmemLayoutAtomLSE = Layout, Int>, + Stride, _1>>; + static_assert(kMaxSplits % CUTE_STATIC_V(shape<0>(GmemLayoutAtomLSE{})) == 0); + using GmemCopyAtomLSE = std::conditional_t< + Has_cp_async, + cute::Copy_Atom, float>, + cute::Copy_Atom, float> + >; + using GmemTiledCopyLSE = decltype( + make_tiled_copy(GmemCopyAtomLSE{}, + GmemLayoutAtomLSE{}, + Layout>>{})); // Val layout, 4 vals per load + + // Otherwise we get IMA when some threads access sLSE, as we're not doing any masking + static_assert((kBlockM * kMaxSplits * AlignmentLSE) % kNThreads == 0, "kNThreads must divide kBlockM * kMaxSplits * AlignmentLSE"); + // This works for kBlockMSmem = 8, 16, 32, 64, 128, no bank conflicts + using SmemLSESwizzle = std::conditional_t< + kBlockMSmem == 8, + Swizzle<5, 0, 5>, + std::conditional_t, Swizzle<3, 2, 3>> + >; + using SmemLayoutAtomLSE = + decltype(composition(SmemLSESwizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutLSE = decltype(tile_to_shape(SmemLayoutAtomLSE{}, Shape, Int>{})); + + using SmemLayoutO = Layout, Int, Int>, + Stride, _1, Int>>; + + // We want each column (kMaxSplits) to be processed by threads in the same warp. + // To reduce the number of shuffles, we want as few threads on the same column as possible. + // E.g., if kBlockM is divisible by 64, and there are 256 threads, we want 4 threads (0, 1, 2, 4) per column + // have have 64 such quads. + static_assert(MaxThreadsPerBlock % kBlockMSmem == 0, "MaxThreadsPerBlock must be a multiple of kBlockMSmem"); + static constexpr int kSmemThreadsPerColLSEt = MaxThreadsPerBlock / kBlockMSmem; + static_assert(cutlass::NumThreadsPerWarp % kSmemThreadsPerColLSEt == 0, "kSmemThreadsPerColLSEt must divide NumThreadsPerWarp"); + using S2RLayoutAtomLSE = Layout, Int>>; + using S2RTiledCopyLSE = decltype(make_tiled_copy(cute::Copy_Atom{}, S2RLayoutAtomLSE{}, Layout<_1>{})); + + using ShapeOPartial = cute::Shape; // (seqlen, d, num_splits, head, batch) + using StrideOPartial = cute::Stride; + using ShapeLSEPartial = cute::Shape; // (seqlen, num_splits, head, batch) + using StrideLSEPartial = cute::Stride<_1, int64_t, int64_t, int64_t>; // (seqlen, num_splits, head, batch) + using ShapeO = cute::Shape; // (seqlen, d, head, batch) + using StrideO = cute::Stride; + using ShapeLSE = cute::Shape; // (seqlen, head, batch) + using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen, head, batch) + + struct SharedStorage : cute::aligned_struct<128> { + cute::array_aligned> smem_lse_partial; + cute::array_aligned smem_max_valid_split; + cute::array_aligned> smem_o_partial; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + + // Device side arguments + struct Arguments { + ElementPartial const* ptr_O_partial; + ShapeOPartial const shape_O_partial; + StrideOPartial const stride_O_partial; + float const* ptr_LSE_partial; + ShapeLSEPartial const shape_LSE_partial; + StrideLSEPartial const stride_LSE_partial; + Element* ptr_O; + StrideO const stride_O; + float* ptr_LSE; + StrideLSE const stride_LSE; + int const* cu_seqlens = nullptr; + int const* seqused = nullptr; + }; + + // Kernel entry point API + struct Params { + ElementPartial const* ptr_O_partial; + ShapeOPartial const shape_O_partial; + StrideOPartial const stride_O_partial; + float const* ptr_LSE_partial; + 
ShapeLSEPartial const shape_LSE_partial; + StrideLSEPartial const stride_LSE_partial; + Element* ptr_O; + StrideO const stride_O; + float* ptr_LSE; + StrideLSE const stride_LSE; + cutlass::FastDivmod seqlen_divmod, head_divmod; + int const* cu_seqlens = nullptr; + int const* seqused = nullptr; + }; + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static + Params + to_underlying_arguments(Arguments const& args) { + assert(get<1>(args.shape_LSE_partial) <= kMaxSplits); + return { + args.ptr_O_partial, + args.shape_O_partial, + args.stride_O_partial, + args.ptr_LSE_partial, + args.shape_LSE_partial, + args.stride_LSE_partial, + args.ptr_O, + args.stride_O, + args.ptr_LSE, + args.stride_LSE, + cutlass::FastDivmod(get<0>(args.shape_LSE_partial)), cutlass::FastDivmod(get<2>(args.shape_LSE_partial)), + args.cu_seqlens, + args.seqused + }; + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse_partial.data()), SmemLayoutLSE{}); + Tensor sMaxValidSplit = make_tensor(make_smem_ptr(shared_storage.smem_max_valid_split.data()), Shape>{}); + Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o_partial.data()), SmemLayoutO{}); + + int const thread_idx = threadIdx.x; + int const m_block = blockIdx.x; + int const batch = !Varlen ? 0 : blockIdx.y; + int const num_splits = get<1>(params.shape_LSE_partial); + flash::SeqlenInfo seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused}; + int const offset = seqlen_info.offset; + int const seqlen = seqlen_info.seqlen; + int max_idx = seqlen * get<2>(params.shape_LSE_partial) * get<3>(params.shape_LSE_partial); + + cutlass::FastDivmod seqlen_divmod_dynamic(seqlen); + + // Step 1: load LSE_partial from gmem -> smem + Tensor mLSEpartial = make_tensor(make_gmem_ptr(params.ptr_LSE_partial + offset * get<0>(params.stride_LSE_partial)), select<1, 0, 2, 3>(params.shape_LSE_partial), select<1, 0, 2, 3>(params.stride_LSE_partial)); // (num_splits, seqlen, head, batch) + Tensor mLSEpartial_copy = cute::tiled_divide(mLSEpartial, Shape<_1, Int>{}); + GmemTiledCopyLSE gmem_tiled_copy_LSE; + auto gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_thread_slice(thread_idx); + Tensor tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE); + + // Construct identity layout for sLSE + Tensor cLSE = make_identity_tensor(make_shape(size<0>(sLSE), size<1>(sLSE))); // (NUM_SPLITS, BLK_M) -> (num_splits, blk_m) + // Repeat the partitioning with identity layouts + Tensor tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE); + + #pragma unroll + for (int m = 0; m < size<2>(tLSEcLSE); ++m) { + int mi = int(get<1>(tLSEcLSE(_0{}, _0{}, m))); + int idx = m_block * kBlockM + mi; + if (idx < max_idx) { + int m_idx, bidh, bidb; + if constexpr (!Varlen) { + bidb = params.head_divmod.divmod(bidh, params.seqlen_divmod.divmod(m_idx, idx)); + } else { + bidh = seqlen_divmod_dynamic.divmod(m_idx, idx); + bidb = 0; + } + Tensor mLSEpartial_cur_copy = mLSEpartial_copy(_, _, m_idx, bidh, bidb); + #pragma unroll + for (int s = 0; s < size<1>(tLSEcLSE); ++s) { + int si = get<0>(tLSEcLSE(_0{}, s, _0{})); + // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && thread_idx < 32) { printf("thread_idx = %d, m = %d, s = %d, addr = %p, bank = %d\n", thread_idx, m, s, reinterpret_cast(&(tLSEsLSE(_0{}, s, m))), reinterpret_cast(&(tLSEsLSE(_0{}, s, m))) / 4 % 32);} + if (si < num_splits) { + 
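+                        // In-range split: stage its LSE through the (possibly cp.async) copy
+                        // atom. Out-of-range slots are filled with -INFINITY below, so they
+                        // contribute exp(-inf) = 0 when the splits are combined.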
cute::copy(gmem_tiled_copy_LSE, mLSEpartial_cur_copy(_, si), tLSEsLSE(_, s, m)); + } else { + cute::fill(tLSEsLSE(_, s, m), -INFINITY); + } + } + } else { + // We don't need to zero out the rest of the LSEs, as we will not write the output to gmem + // cute::fill(tLSEsLSE(_, _, m), -INFINITY); + } + } + if constexpr (Has_cp_async) { cute::cp_async_fence(); } + + // Step 2: Load O_partial from gmem -> smem for split = 0, 1, ..., kStages - 2. + // We want these async loads to be in flight as we compute the LSE. + GmemTiledCopyAccum gmem_tiled_copy_O_partial; + auto gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_thread_slice(thread_idx); + // Construct identity layout for gO + Tensor cO = cute::make_identity_tensor(TileShape_MK{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O_partial.partition_D(cO); + Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset * get<0>(params.stride_O_partial)), params.shape_O_partial, params.stride_O_partial); // (seqlen, d, num_splits, head, batch) + + // Precompute these values to avoid recomputing them in the loop + Tensor tOmidx = make_tensor(make_shape(size<1>(tOcO))); + Tensor tObidh = make_tensor(make_shape(size<1>(tOcO))); + Tensor tObidb = make_tensor(make_shape(size<1>(tOcO))); + Tensor tOrOptr = make_tensor(make_shape(size<1>(tOcO))); + #pragma unroll + for (int m = 0; m < size<1>(tOcO); ++m) { + int mi = get<0>(tOcO(_0{}, m, _0{})); + int idx = m_block * kBlockM + mi; + if constexpr (!Varlen) { + tObidb[m] = params.head_divmod.divmod(tObidh(m), params.seqlen_divmod.divmod(tOmidx(m), idx)); + } else { + tObidh[m] = seqlen_divmod_dynamic.divmod(tOmidx(m), idx); + tObidb[m] = 0; + } + tOrOptr[m] = &mOpartial(tOmidx(m), _0{}, _0{}, tObidh(m), tObidb(m)); + if (idx >= max_idx) { + tObidb[m] = -1; + } + } + + Tensor tOpO = make_tensor(make_shape(size<2>(tOcO))); + if constexpr (!(Is_even_K)) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O_partial); } + } + + Tensor tOsOpartial = gmem_thr_copy_O_partial.partition_D(sO); + + auto load_O_partial = [&] (int split, int stage) { + Tensor tOsOpartial_cur = tOsOpartial(_, _, _, stage); + #pragma unroll + for (int m = 0; m < size<1>(tOcO); ++m) { + if (tObidb(m) >= 0) { + Tensor mOpartial_cur = make_tensor(make_gmem_ptr(tOrOptr[m]), mOpartial(_0{}, _, _, _0{}, _0{}).layout()); + Tensor mOpartial_cur_copy = cute::tiled_divide(mOpartial_cur, Shape>{}); + #pragma unroll + for (int k = 0; k < size<2>(tOcO); ++k) { + int k_idx = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerLoad; + if (Is_even_K || tOpO(k)) { + cute::copy(gmem_tiled_copy_O_partial, mOpartial_cur_copy(_, k_idx, split), tOsOpartial_cur(_, m, k)); + } + } + } + } + }; + + for (int s = 0; s < kStages - 1; ++s) { + if (s < num_splits) { load_O_partial(s, s); } + if constexpr (Has_cp_async) { cute::cp_async_fence(); } + } + + // Step 3: load and transpose LSE_partial from smem -> rmem + if constexpr (Has_cp_async) { cutlass::arch::cp_async_wait(); } + __syncthreads(); + + S2RTiledCopyLSE s2r_tiled_copy_LSE; + auto s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_thread_slice(thread_idx); + Tensor ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE); + Tensor ts2rrLSE = make_fragment_like(ts2rsLSE); + cute::copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE); + + // Step 4: compute the final LSE along the split dimension + Tensor lse_sum = make_tensor(make_shape(size<2>(ts2rrLSE))); + Tensor ts2rcLSE = 
s2r_thr_copy_LSE.partition_D(cLSE); + // We compute the max valid split for each row to short-circuit the computation later + Tensor max_valid_split = make_tensor(make_shape(size<2>(ts2rrLSE))); + static_assert(CUTE_STATIC_V(size<0>(ts2rrLSE)) == 1); + #pragma unroll + for (int m = 0; m < size<2>(ts2rrLSE); ++m) { + float lse_max = ts2rrLSE(_0{}, _0{}, m); + #pragma unroll + for (int s = 1; s < size<1>(ts2rrLSE); ++s) { lse_max = max(lse_max, ts2rrLSE(_0{}, s, m)); } + MaxOp max_op; + lse_max = Allreduce::run(lse_max, max_op); + int max_valid_idx = -1; + #pragma unroll + for (int s = 0; s < size<1>(ts2rrLSE); ++s) { + if (ts2rrLSE(_0{}, s, m) != -INFINITY) { max_valid_idx = get<0>(ts2rcLSE(_0{}, s, _0{})); } + } + MaxOp max_int_op; + max_valid_split[m] = Allreduce::run(max_valid_idx, max_int_op); + float lse_max_cur = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf + float lse_sum_cur = 0.f; + #pragma unroll + for (int s = 0; s < size<1>(ts2rrLSE); ++s) { + float scale = expf(ts2rrLSE(_0{}, s, m) - lse_max_cur); + lse_sum_cur += scale; + // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && thread_idx < 32) { printf("thread_idx = %d, m = %d, s = %d, addr = %p, bank = %d\n", thread_idx, m, s, reinterpret_cast(&(ts2rsLSE(_0{}, s, m))), reinterpret_cast(&(ts2rsLSE(_0{}, s, m))) / 4 % 32);} + // ts2rsLSE(_0{}, m, s) = scale; + ts2rrLSE(_0{}, s, m) = scale; + } + SumOp sum_op; + lse_sum_cur = Allreduce::run(lse_sum_cur, sum_op); + lse_sum(m) = logf(lse_sum_cur) + lse_max; + float inv_sum = (lse_sum_cur == 0.f || lse_sum_cur != lse_sum_cur) ? 0.f : 1.f / lse_sum_cur; + #pragma unroll + for (int s = 0; s < size<1>(ts2rrLSE); ++s) { ts2rrLSE(_0{}, s, m) *= inv_sum; } + } + // Store the scales exp(lse - lse_logsum) back to smem + cute::copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE); + + // Step 5: store final LSE back to gmem + auto shape_LSE = select<0, 2, 3>(params.shape_LSE_partial); + Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset * get<0>(params.stride_LSE)), shape_LSE, params.stride_LSE); + #pragma unroll + for (int m = 0; m < size<2>(ts2rrLSE); ++m) { + if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to gmem + int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m))); + int idx = m_block * kBlockM + mi; + if (idx < max_idx) { + int m_idx, bidh, bidb; + if constexpr (!Varlen) { + bidb = params.head_divmod.divmod(bidh, params.seqlen_divmod.divmod(m_idx, idx)); + } else { + bidh = seqlen_divmod_dynamic.divmod(m_idx, idx); + bidb = 0; + } + // printf("thread_idx = %d, m = %d, mi = %d, idx = %d, m_idx = %d, bidh = %d, bidb = %d, lse_sum = %f\n", thread_idx, m, mi, idx, m_idx, bidh, bidb, lse_sum(m)); + mLSE(m_idx, bidh, bidb) = lse_sum(m); + } + if (mi < kBlockM) { sMaxValidSplit[mi] = max_valid_split[m]; } + } + } + + // Step 6: read O_partial from gmem -> smem -> rmem and accumulate the final O + __syncthreads(); + int thr_max_valid_split = sMaxValidSplit[get<0>(tOcO(_0{}, _0{}, _0{}))]; + #pragma unroll + for (int m = 1; m < size<1>(tOcO); ++m) { thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[get<0>(tOcO(_0{}, m, _0{}))]); } + Layout tOrOpartial_layout = gmem_thr_copy_O_partial.partition_S(make_tensor(TileShape_MK{})).layout(); + Tensor tOrOpartial = make_fragment_like(tOrOpartial_layout); + Tensor tOrO = make_fragment_like(tOrOpartial); + clear(tOrO); + int stage_load = kStages - 1, stage_compute = 0; + #pragma unroll 4 // Already tuned for speed + for (int s = 0; s <= thr_max_valid_split; ++s) { + 
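+            // For split s: read its rescale factor exp(LSE_s - LSE) (written back into sLSE
+            // in Step 4), kick off the async load of split s + kStages - 1 into the next smem
+            // stage, then wait for the stage holding split s and accumulate
+            // scale(m) * O_partial into tOrO. The kStages (= 4) stages keep loads in flight
+            // while earlier splits are being reduced.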
Tensor scale = make_tensor(make_shape(size<1>(tOrOpartial))); + #pragma unroll + for (int m = 0; m < size<1>(tOrOpartial); ++m) { scale(m) = sLSE(s, get<0>(tOcO(_0{}, m, _0{}))); } + + if (s + kStages - 1 <= thr_max_valid_split) { load_O_partial(s + kStages - 1, stage_load); } + if constexpr (Has_cp_async) { cute::cp_async_fence(); } + stage_load = stage_load < kStages - 1 ? stage_load + 1 : 0; + if constexpr (Has_cp_async) { cutlass::arch::cp_async_wait(); } + // We don't need __syncthreads() because each thread is just reading its own data from smem + cute::copy(Copy_Atom, ElementPartial>{}, + tOsOpartial(_, _, _, stage_compute), tOrOpartial); + stage_compute = stage_compute < kStages - 1 ? stage_compute + 1 : 0; + + #pragma unroll + for (int m = 0; m < size<1>(tOrOpartial); ++m) { + if (tObidb(m) >= 0 && scale(m) > 0.f) { + #pragma unroll + for (int k = 0; k < size<2>(tOrOpartial); ++k) { + if (Is_even_K || tOpO(k)) { + Tensor rOpartial = make_tensor_like(tOrOpartial(_, m, k)); + flash::convert_type_out(tOrOpartial(_, m, k), rOpartial); + #pragma unroll + for (int i = 0; i < size<0>(tOrOpartial); ++i) { + tOrO(i, m, k) += scale(m) * rOpartial[i]; + } + } + } + } + } + } + + // Step 7: Write the final O to gmem + Tensor rO = make_tensor_like(tOrO); + flash::convert_type_out(tOrO, rO); + auto shape_O = select<0, 1, 3, 4>(params.shape_O_partial); + Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset * get<0>(params.stride_O)), shape_O, params.stride_O); + Tensor mO_copy = cute::tiled_divide(mO, Shape<_1, Int>{}); + GmemTiledCopy gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); + + #pragma unroll + for (int m = 0; m < size<1>(tOcO); ++m) { + if (tObidb(m) >= 0) { + #pragma unroll + for (int k = 0; k < size<2>(tOcO); ++k) { + int k_idx = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerLoad; + if (Is_even_K || tOpO(k)) { + cute::copy(gmem_tiled_copy_O, rO(_, m, k), mO_copy(_, tOmidx(m), k_idx, tObidh(m), tObidb(m))); + } + } + } + } + + } + +}; + +} // namespace flash diff --git a/hopper/src/include/flash_fwd_kernel_sm90.h b/hopper/src/include/flash_fwd_kernel_sm90.h new file mode 100644 index 0000000..b0fcbc0 --- /dev/null +++ b/hopper/src/include/flash_fwd_kernel_sm90.h @@ -0,0 +1,393 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" + +#include +#include +#include +#include +#include +#include +#include "cutlass/pipeline/pipeline.hpp" + +#include "flash.h" +#include "seqlen.h" +#include "utils.h" +#include "softmax.h" +#include "tile_scheduler.hpp" + +namespace flash { + +using namespace cute; + +template +class FlashAttnFwdSm90 { + +public: + + // Type Aliases + using CollectiveMainloop = CollectiveMainloop_; + using CollectiveEpilogue = CollectiveEpilogue_; + static constexpr bool Is_causal = CollectiveMainloop::Is_causal; + static constexpr bool Is_local = CollectiveMainloop::Is_local; + static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen); + static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap; + static constexpr bool Varlen = CollectiveMainloop::Varlen; + static constexpr bool PagedKV = CollectiveMainloop::PagedKV; + static constexpr bool Split = CollectiveMainloop::Split; + static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8; + static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V; + static constexpr bool AppendKV = CollectiveMainloop::AppendKV; + static constexpr bool Use_TMA_Q = CollectiveMainloop::Use_TMA_Q; + static constexpr bool Use_TMA_KV = CollectiveMainloop::Use_TMA_KV; + static constexpr bool Use_TMA_O = CollectiveEpilogue::Use_TMA_O; + static constexpr bool PackGQA = CollectiveMainloop::PackGQA; + static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads; + using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t; + + // Mainloop derived types + using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK; + using TiledMma0 = typename CollectiveMainloop::TiledMma0; + using TiledMma1 = typename CollectiveMainloop::TiledMma1; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ClusterShape = typename CollectiveMainloop::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + using BarrierQ = std::conditional_t; + // Epilogue derived types + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(ArchTag::kMinComputeCapability >= 90); + + using TileScheduler = TileScheduler_; + using TileSchedulerArguments = typename flash::TileSchedulerArguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr uint32_t NumLoadWarpGroups = 1; + static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMma0{})) / cutlass::NumThreadsPerWarpGroup; + static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma0{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup); + + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); + + /// Register requirement for Load and Math WGs + // If we use cp.async to load K and V, we need more registers for the producer WG. + static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 1 ? 56 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 24 : 40) : 32); + static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 1 ? 256 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 
240 : 232) : 160); + // static constexpr uint32_t LoadRegisterRequirement = 40; + // static constexpr uint32_t MmaRegisterRequirement = 232; + // If you want to print from the producer warp, you'd need to increase the number of registers + // Otherwise you'll get CUDA error. + // static constexpr uint32_t LoadRegisterRequirement = 40; + // static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 232 : 152; + + // Kernel level shared memory storage + // We overlap the shared memory for the mainloop and epilogue. However, we only want smem_o to overlap with smem_v + // and nothing else, so we'll pad in case sizeof(smem_o) > sizeof(smem_v). + static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage)) - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v))); + static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 0 : mainloop_smem_padding_; + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128> { + union { + struct { + cute::array padding_; + typename CollectiveMainloop::TensorStorage mainloop; + }; + // We want smem_o to line up with the start of smem_v + typename CollectiveEpilogue::TensorStorage epilogue; + }; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16> { + alignas(16) BarrierQ barrier_Q; + alignas(16) cutlass::arch::ClusterBarrier barrier_O; + alignas(16) typename CollectiveMainloop::MainloopPipelineK::SharedStorage pipeline_k; + alignas(16) typename CollectiveMainloop::MainloopPipelineV::SharedStorage pipeline_v; + alignas(16) typename CollectiveMainloop::MainloopPipelineVt::SharedStorage pipeline_vt; + alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_k_new; + alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_v_new; + alignas(16) typename TileScheduler::SharedStorage smem_scheduler; + } pipelines; + + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Device side arguments + struct Arguments { + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + cutlass::KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + cutlass::KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. 
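+  // Illustrative host-side flow (a sketch only — variable names are made up and the actual
+  // launcher lives elsewhere in this repo):
+  //   Arguments args{mainloop_args, epilogue_args, hw_info, scheduler_args};
+  //   Params params = FlashAttnFwdSm90<...>::to_underlying_arguments(args);
+  //   dim3 grid  = FlashAttnFwdSm90<...>::get_grid_shape(params);
+  //   dim3 block = FlashAttnFwdSm90<...>::get_block_shape();
+  //   // then launch the kernel with SharedStorageSize bytes of dynamic shared memory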
+ static + Params + to_underlying_arguments(Arguments const& args) { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + return { + CollectiveMainloop::to_underlying_arguments(args.mainloop), + CollectiveEpilogue::to_underlying_arguments(args.epilogue), + hw_info, + TileScheduler::to_underlying_arguments(args.scheduler) + }; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + + static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup; + static constexpr int MmaThreadOffset = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup; + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + + using MainloopPipelineK = typename CollectiveMainloop::MainloopPipelineK; + using MainloopPipelineV = typename CollectiveMainloop::MainloopPipelineV; + using MainloopPipelineVt = typename CollectiveMainloop::MainloopPipelineVt; + using MainloopPipelineKVNew = typename CollectiveMainloop::MainloopPipelineKVNew; + using PipelineState = typename CollectiveMainloop::PipelineState; + using PipelineParamsK = typename MainloopPipelineK::Params; + using PipelineParamsV = typename MainloopPipelineV::Params; + using PipelineParamsVt = typename MainloopPipelineVt::Params; + using PipelineParamsKVNew = typename MainloopPipelineKVNew::Params; + + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + int const lane_predicate = cute::elect_one_sync(); + int const warp_idx = cutlass::canonical_warp_idx_sync(); + + // Issue Tma Descriptor Prefetch from a single thread + if (warp_idx == 0 && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + // Obtain warp index + int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; + int warp_group_idx = cutlass::canonical_warp_group_idx(); + + if (warp_idx == 0 && lane_predicate) { + shared_storage.pipelines.barrier_Q.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/); + shared_storage.pipelines.barrier_O.init(size(ClusterShape{}) * (Use_TMA_O ? 1 : NumMmaThreads) /*numThreads*/); + } + + // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init(); + PipelineParamsK pipeline_params_k; + pipeline_params_k.role = warp_group_idx == 0 + ? 
MainloopPipelineK::ThreadCategory::Producer + : MainloopPipelineK::ThreadCategory::Consumer; + if constexpr (Use_TMA_KV) { + pipeline_params_k.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK; + pipeline_params_k.is_leader = warp_group_thread_idx == 0; + pipeline_params_k.num_consumers = NumMmaThreads; + } else { + pipeline_params_k.consumer_arv_count = NumMmaThreads; + pipeline_params_k.producer_arv_count = NumProducerThreads; + } + + MainloopPipelineK pipeline_k = [&] { + if constexpr (Use_TMA_KV) { + return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k, ClusterShape{}); + } else { + return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k); + } + }(); + // MainloopPipelineV pipeline_v(shared_storage.pipelines.pipeline_v, pipeline_params_v, ClusterShape{}); + MainloopPipelineV pipeline_v = [&] { + if constexpr (!Transpose_V) { + // TODO: set v instead of k + pipeline_params_k.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV; + static_assert(is_same_v); + if constexpr (Use_TMA_KV) { + return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_k, ClusterShape{}); + } else { + return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_k); + } + } else { + PipelineParamsV pipeline_params_v; + pipeline_params_v.role = warp_group_idx == 0 + ? MainloopPipelineV::ThreadCategory::Producer + : MainloopPipelineV::ThreadCategory::Consumer; + pipeline_params_v.producer_arv_count = NumProducerThreads; + pipeline_params_v.consumer_arv_count = NumMmaThreads; + return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_v); + } + }(); + static_assert(is_same_v); + // If we need to transpose V (e.g. FP8 and V is row-major), we use pipeline_vt for the TMA, then + // the producer WG will read from pipeline_vt and write to pipeline_v. + // If we don't need to transpose V, we use pipeline_v for the TMA, and pipeline_vt won't be used. + // Technically for pipeline_params_vt, warp0 of WG0 is the producer and all of WG0 are consumers. + // However, the thread role isn't used in the pipeline implementation. + MainloopPipelineVt pipeline_vt = [&] { + if constexpr (Use_TMA_KV) { + pipeline_params_k.num_consumers = NumProducerThreads; // TMA_V is only consumed by the producer WG + return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_k, ClusterShape{}); + } else { + pipeline_params_k.consumer_arv_count = NumProducerThreads; // TMA_V is only consumed by the producer WG + return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_k); + } + }(); + + PipelineParamsKVNew pipeline_params_kv_new; + pipeline_params_kv_new.role = warp_group_idx == 0 + ? 
MainloopPipelineKVNew::ThreadCategory::Producer + : MainloopPipelineKVNew::ThreadCategory::Consumer; + pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK; + pipeline_params_kv_new.is_leader = warp_group_thread_idx == 0; + pipeline_params_kv_new.num_consumers = NumMmaThreads; + auto pipeline_k_new = cute::conditional_return(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_k_new, pipeline_params_kv_new, ClusterShape{}), nullptr); + auto pipeline_v_new = cute::conditional_return(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_v_new, pipeline_params_kv_new, ClusterShape{}), nullptr); + + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue; + + // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); + } else { + __syncthreads(); + } + + if (warp_group_idx == 0) { // Producer + cutlass::arch::warpgroup_reg_dealloc(); + + // The pipelines for AppendKV and main attention are different, since e.g. main attention + // might use cp.async to load KV (if PagedKV) while AppendKV always uses TMA to load + // KV_new. Since the pipeline states are different, we have to manually sync to make + // sure the two pipelines don't race when accessing smem_k and smem_v. + PipelineState smem_pipe_write = cutlass::make_producer_start_state(); + PipelineState smem_pipe_write_new = cutlass::make_producer_start_state(); + int work_idx = 0; + + TileScheduler scheduler(reinterpret_cast(&shared_storage.pipelines.smem_scheduler)); + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); + static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp; + if constexpr (SingleProducerWarp) { + if (warp_idx_in_warpgroup != 0) { return; } + } + if (!SingleProducerWarp && warp_idx_in_warpgroup != 0) { scheduler.init_consumer(); } + + // Load Q, K, V + for (auto work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_initial_work(params.scheduler) : scheduler.template get_initial_work(params.scheduler); + work_tile_info.is_valid(params.scheduler); + work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_next_work(params.scheduler, work_tile_info) : scheduler.template get_next_work(params.scheduler, work_tile_info)) { + + auto block_coord = work_tile_info.get_block_coord(params.scheduler); + SeqlenInfo_t seqlen_info{ + get<2>(block_coord) /*bidb*/, + get<0>(params.mainloop.shape_Q), + !PagedKV ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable), + get<0>(params.mainloop.shape_K_new), + params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new, + params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, + }; + auto scheduler_prefetch = [&scheduler, ¶ms, &work_tile_info]() { + scheduler.prefetch_next_work(params.scheduler, work_tile_info); + }; + // pipeline_vt won't be used if we don't need to transpose V. 
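+        // For each scheduled tile, the producer warp group issues the Q/K/V loads into the
+        // multi-stage smem pipelines here; the consumer warp groups drain pipeline_k /
+        // pipeline_v inside collective_mainloop.mma below.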
+ collective_mainloop.load(params.mainloop, pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, + shared_storage, scheduler_prefetch, seqlen_info, block_coord, work_idx); + } + collective_mainloop.load_tail(pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, shared_storage, work_idx); + } else { // Consumer + cutlass::arch::warpgroup_reg_alloc(); + + TileScheduler scheduler(reinterpret_cast(&shared_storage.pipelines.smem_scheduler)); + // Initialize matmul objects. + TiledMma1 tiled_mma1; + + PipelineState smem_pipe_read; + PipelineState smem_pipe_read_new; + + scheduler.init_consumer(); + collective_mainloop.mma_init(); + + int work_idx = 0; + CUTLASS_PRAGMA_NO_UNROLL + for (auto work_tile_info = scheduler.template get_initial_work(params.scheduler); + work_tile_info.is_valid(params.scheduler); + work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info)) { + // Attention output (GEMM-II) accumulator. + Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{})); + float softmax_scale_log2 = params.mainloop.softmax_scale_log2; + // If there's tanh softcap, the scaling will be done before tanh. + auto block_coord = work_tile_info.get_block_coord(params.scheduler); + int const bidb = get<2>(block_coord); + if constexpr (Is_FP8 && !Has_softcap) { + int const bidh = get<1>(block_coord); + int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh; + float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)]; + float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)]; + softmax_scale_log2 *= q_descale * k_descale; + } + flash::Softmax<2 * (2 * kBlockM / NumMmaThreads), /*Max_offset=*/!Is_FP8 ? 0 : 8> softmax(softmax_scale_log2); + + SeqlenInfo_t seqlen_info{ + bidb, + get<0>(params.mainloop.shape_Q), + !PagedKV ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable), + get<0>(params.mainloop.shape_K_new), + params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new, + params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, + }; + bool tile_valid = collective_mainloop.mma( + params.mainloop, pipeline_k, pipeline_v, smem_pipe_read, + tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage); + if (tile_valid) { + collective_epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma1, + threadIdx.x - MmaThreadOffset, block_coord); + } else { + collective_epilogue.template store_zero(params.epilogue, threadIdx.x - MmaThreadOffset, block_coord); + } + } + collective_epilogue.store_tail(); + } + + } + +}; + +} // namespace flash diff --git a/hopper/src/include/flash_qpack_kernel.h b/hopper/src/include/flash_qpack_kernel.h new file mode 100644 index 0000000..9de87a6 --- /dev/null +++ b/hopper/src/include/flash_qpack_kernel.h @@ -0,0 +1,369 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include + +#include +#include +#include + +#include "block_info.h" +#include "kernel_traits.h" +#include "utils.h" +#include "qpack.h" +#include "dequantize.h" +// #include "include/softmax.h" +// #include "include/mask.h" +// #include "include/dropout.h" +// #include "include/rotary.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_qpack_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int blockN_idx) { + + using Element = typename Kernel_traits::Element; + using ElementKVPack = typename Kernel_traits::ElementKVPack; + using SharedStorage = typename Kernel_traits::SharedStorage; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + SharedStorage& shared_storage = *reinterpret_cast(smem_); + + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kBlockP = Kernel_traits::kBlockP; + constexpr int kBlockK_params = Kernel_traits::kBlockK_params; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kHeadDim_pack = Kernel_traits::kHeadDim_pack; + constexpr int kHeadDim_k = Kernel_traits::kHeadDim_k; + constexpr int kHeadDim_k_params = Kernel_traits::kHeadDim_k_params; + constexpr int kHeadDim_v_params = Kernel_traits::kHeadDim_v_params; + constexpr int kNWarps = Kernel_traits::kNWarps; + constexpr int tile_paramsk_j = Kernel_traits::tile_paramsk_j; + constexpr int tile_paramsk_k = Kernel_traits::tile_paramsk_k; + constexpr int tile_paramsk_g = Kernel_traits::tile_paramsk_g; + constexpr int tile_paramsv_k = Kernel_traits::tile_paramsv_k; + constexpr int num_bits = Kernel_traits::num_bits; + constexpr int group_size = Kernel_traits::group_size; + constexpr int num_params = Kernel_traits::num_params; + + const BlockInfo binfo(params, bidb); + + const int bidb_cache = bidb; + const int *block_table = nullptr; + const int block_table_idx = 0; + const int block_table_offset = 0; + const int block_table_offset_pack = 0; + + const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + + blockN_idx * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_k_pack = binfo.k_offset(params.K_pack_batch_stride, params.K_pack_row_stride, bidb_cache) + + blockN_idx * kBlockP * params.K_pack_row_stride + (bidh / params.h_h_k_ratio) * params.K_pack_head_stride; + const index_t row_offset_k_params = binfo.k_offset(params.k_params_batch_stride, params.k_params_row_stride, bidb) + + blockN_idx * kBlockK_params * params.k_params_row_stride + (bidh / params.h_h_k_ratio) * params.k_params_head_stride; + + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + + blockN_idx * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + const index_t row_offset_v_pack = binfo.k_offset(params.v_pack_batch_stride, params.v_pack_row_stride, bidb_cache) + + blockN_idx * kBlockN * params.v_pack_row_stride + (bidh / params.h_h_k_ratio) * params.v_pack_head_stride; + const index_t row_offset_v_params = binfo.k_offset(params.v_params_batch_stride, params.v_params_row_stride, bidb) + + blockN_idx * kBlockN * params.v_params_row_stride + (bidh / params.h_h_k_ratio) * 
params.v_params_head_stride; + + // Tensor, global memory + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + Tensor gK_pack = make_tensor(make_gmem_ptr(reinterpret_cast(params.K_pack_ptr) + row_offset_k_pack), + Shape, Int>{}, + make_stride(params.K_pack_row_stride, _1{})); + Tensor gK_params = make_tensor(make_gmem_ptr(reinterpret_cast<__half2*>(params.k_params_ptr) + row_offset_k_params), + Shape, Int>{}, + make_stride(params.k_params_row_stride, params.k_params_dim_stride)); + + Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + Shape, Int>{}, + make_stride(params.v_row_stride, _1{})); + Tensor gV_pack = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_pack_ptr) + row_offset_v_pack), + Shape, Int>{}, + make_stride(params.v_pack_row_stride, _1{})); + Tensor gV_params = make_tensor(make_gmem_ptr(reinterpret_cast<__half2*>(params.v_params_ptr) + row_offset_v_params), + Shape, Int>{}, + make_stride(params.v_params_row_stride, params.v_params_dim_stride)); + + Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_K.data()), typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_V.data()), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + Tensor sK_pack = make_tensor(make_smem_ptr(shared_storage.smem_Kpack.data()), typename Kernel_traits::SmemLayoutKPack{}); + Tensor sK_pack_transposed = make_tensor(sK_pack.data(), typename Kernel_traits::SmemLayoutKPacktransposed{}); + Tensor sV_pack = make_tensor(make_smem_ptr(shared_storage.smem_Vpack.data()), typename Kernel_traits::SmemLayoutVPack{}); + Tensor sVt_pack = make_tensor(sV_pack.data(), typename Kernel_traits::SmemLayoutVPacktransposed{}); + Tensor sVtNoSwizzle_pack = make_tensor(sV_pack.data().get(), typename Kernel_traits::SmemLayoutVPacktransposedNoSwizzle{}); + + Tensor sReduce_tmp = make_tensor(make_smem_ptr(shared_storage.smem_reduce_tmp.data()), typename Kernel_traits::SmemLayoutReduce_tmp{}); + + // + // copy: global - shared + // + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + + typename Kernel_traits::GmemTileCopyK_Pack gmem_tiled_copy_k_pack; + auto gmem_thr_copy_k_pack = gmem_tiled_copy_k_pack.get_thread_slice(tidx); + Tensor tKsK_pack_s2g = gmem_thr_copy_k_pack.partition_S(sK_pack); + Tensor tKgK_pack_s2g = gmem_thr_copy_k_pack.partition_D(gK_pack); + Tensor tKgK_pack_g2s = gmem_thr_copy_k_pack.partition_S(gK_pack); + Tensor tKsK_pack_g2s = gmem_thr_copy_k_pack.partition_D(sK_pack); + + typename Kernel_traits::GmemTileCopyV_Pack gmem_tiled_copy_v_pack; + auto gmem_thr_copy_v_pack = gmem_tiled_copy_v_pack.get_thread_slice(tidx); + Tensor tVsV_pack_s2g = gmem_thr_copy_v_pack.partition_S(sV_pack); + Tensor tVgV_pack_s2g = gmem_thr_copy_v_pack.partition_D(gV_pack); + Tensor tVgV_pack_g2s = gmem_thr_copy_v_pack.partition_S(gV_pack); + Tensor tVsV_pack_g2s = gmem_thr_copy_v_pack.partition_D(sV_pack); + + // + // Tensor: Register per thread + // + 
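+    // Per-thread MMA-shaped fragments: tSrK / tSrV hold the Element-typed K and V tiles,
+    // the *_pack fragments hold their low-bit packed counterparts, and the *_dequant
+    // fragments are scratch used further down to check the round-trip dequantization.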
+ typename Kernel_traits::TiledMma tiled_mma; + typename Kernel_traits::TiledMmaK_i4 tiled_mma_i4; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + auto thr_mma_i4 = tiled_mma_i4.get_thread_slice(tidx); + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tSrK_dequant = thr_mma.partition_fragment_B(sK); + Tensor tSrK_pack_tmp = thr_mma_i4.partition_fragment_B(sK_pack_transposed); // (MMA,MMA_N,MMA_K) + Tensor tSrK_pack = make_fragment_like(tSrK_pack_tmp); + + Tensor tSrV = thr_mma.partition_fragment_B(sVtNoSwizzle); + Tensor tSrV_dequant = thr_mma.partition_fragment_B(sVtNoSwizzle); + Tensor tSrV_pack_tmp = thr_mma_i4.partition_fragment_B(sVtNoSwizzle_pack); + Tensor tSrV_pack = make_fragment_like(tSrV_pack_tmp); + + // + // copy: shared - register + // + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + Tensor tSrK_view = smem_thr_copy_K.retile_D(tSrK); + + auto smem_tiled_copy_kv_pack = make_tiled_copy_B(typename Kernel_traits::R2SCopyAtomPack{}, tiled_mma_i4); + auto smem_thr_copy_kv_pack = smem_tiled_copy_kv_pack.get_thread_slice(tidx); + Tensor tSrK_pack_r2s_view = smem_thr_copy_kv_pack.retile_S(tSrK_pack); + Tensor tSsK_pack_r2s = smem_thr_copy_kv_pack.partition_D(sK_pack); + Tensor tSrV_pack_r2s_view = smem_thr_copy_kv_pack.retile_S(tSrV_pack); + Tensor tSsV_pack_r2s = smem_thr_copy_kv_pack.partition_D(sVt_pack); + Tensor tSsK_pack_s2r = smem_thr_copy_kv_pack.partition_S(sK_pack); + Tensor tSrK_pack_s2r_view = smem_thr_copy_kv_pack.retile_D(tSrK_pack); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_tiled_copy_V_pack = make_tiled_copy_B(typename Kernel_traits::R2SCopyAtomPack{}, tiled_mma_i4); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + auto smem_thr_copy_V_pack = smem_tiled_copy_V_pack.get_thread_slice(tidx); + Tensor tSsV = smem_thr_copy_V.partition_S(sVt); + Tensor tSrV_view = smem_thr_copy_V.retile_D(tSrV); + Tensor tSsV_pack_s2r = smem_thr_copy_V_pack.partition_S(sVt_pack); + Tensor tSrV_pack_s2r_view = smem_thr_copy_V_pack.retile_D(tSrV_pack); + + // Advance gK + cute::copy(gmem_tiled_copy_QKV, tKgK, tKsK); + + cute::cp_async_fence(); + flash::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + cute::copy(gmem_tiled_copy_QKV, tVgV, tVsV); + cute::cp_async_fence(); + + cute::copy(smem_tiled_copy_K, tSsK, tSrK_view); + + // quantize kv + using TensorParamsKC = decltype(make_tensor(make_shape(Int<4 * num_params>{}, Int{}))); + using TensorParamsVG = decltype(make_tensor(make_shape(Int{}, Int{}))); // TODO: need to change, hardcode num_bits + using TensorParamsG = decltype(make_tensor(make_shape(Int{}))); + + TensorParamsKC tScales_k_c, tZeros_k_c; + TensorParamsVG tScales_v_c, tZeros_v_c; + TensorParamsG tScales_k_g, tZeros_k_g; + + if (Kernel_traits::quant_mode == 1) { + quant::qpack_Kchannel_Vtensor(tSrK, tSrK_pack, tScales_k_c, tZeros_k_c, sReduce_tmp, num_params); + } else { + quant::quant_Ktensor(tSrK, tSrK_pack, tScales_k_g, tZeros_k_g, num_params); + } + + auto tScales_k_h2_c = cute::recast<__half2>(tScales_k_c); + auto tZeros_k_h2_c = cute::recast<__half2>(tZeros_k_c); + auto tScales_k_h2_g = cute::recast<__half2>(tScales_k_g); + auto tZeros_k_h2_g = cute::recast<__half2>(tZeros_k_g); + + auto tScales_v_h2 = cute::recast<__half2>(tScales_v_c); + auto tZeros_v_h2 = 
cute::recast<__half2>(tZeros_v_c); + + flash::cp_async_wait<0>(); + __syncthreads(); + cute::copy(smem_tiled_copy_V, tSsV, tSrV_view); + + quant::qpack_Kchannel_Vtensor(tSrV, tSrV_pack, tScales_v_c, tZeros_v_c, sReduce_tmp, num_params); + + const int num_params_2 = num_bits == 2 ? num_params / 2 : num_params; + CUTE_UNROLL + for (int i = 0; i < size<1>(tScales_v_h2); ++i) { + CUTE_UNROLL + for (int j = 0; j < size<0>(tScales_v_h2); ++j) { + gV_params(128 * (i / 8) + 0 + 8 * (i % 8) + 4 * (j / num_params_2) + tidx % 4, j % num_params_2) = tScales_v_h2(j, i); + gV_params(128 * (i / 8) + 64 + 8 * (i % 8) + 4 * (j / num_params_2) + tidx % 4, j % num_params_2) = tZeros_v_h2(j, i); + } + } + + if (Kernel_traits::quant_mode == 1) { + CUTE_UNROLL + for (int i = 0; i < size<1>(tScales_k_h2_c); ++i) { + CUTE_UNROLL + for (int j = 0; j < size<0>(tScales_k_h2_c); ++j) { + gK_params(j % num_params, 0 + 8 * i + 4 * (j / num_params) + tidx % 4) = tScales_k_h2_c(j, i); + gK_params(j % num_params, 64 + 8 * i + 4 * (j / num_params) + tidx % 4) = tZeros_k_h2_c(j, i); + } + } + } else { + CUTE_UNROLL + for (int j = 0; j < size<0>(tScales_k_h2_g); ++j) { + gK_params(0 + 32 * (j / num_params) + tidx / 4, j % num_params) = tScales_k_h2_g(j); + gK_params(64 + 32 * (j / num_params) + tidx / 4, j % num_params) = tZeros_k_h2_g(j); + } + } + + // copy from register to shared memory + cute::copy(smem_tiled_copy_kv_pack, tSrK_pack_r2s_view, tSsK_pack_r2s); + __syncthreads(); + if (kHeadDim == 128 && num_bits == 2) { + if (tidx < 64) { + cute::copy(smem_tiled_copy_kv_pack, tSrV_pack_r2s_view, tSsV_pack_r2s); + } + } else { + cute::copy(smem_tiled_copy_kv_pack, tSrV_pack_r2s_view, tSsV_pack_r2s); + } + + // copy from shared to global + __syncthreads(); + cute::copy(gmem_tiled_copy_k_pack, tKsK_pack_s2g, tKgK_pack_s2g); + __syncthreads(); + cute::copy(gmem_tiled_copy_v_pack, tVsV_pack_s2g, tVgV_pack_s2g); + __syncthreads(); + + ////////////////////////////////////////////////////////////////////////////// + // verify the quantize + // clear(tSrK_pack); + // clear(tSsK_pack_r2s); + // clear(tSrV_pack); + // clear(tSsV_pack_r2s); + + // __syncthreads(); + // cute::copy(gmem_tiled_copy_k_pack, tKgK_pack_g2s, tKsK_pack_g2s); + // cute::copy(gmem_tiled_copy_v_pack, tVgV_pack_g2s, tVsV_pack_g2s); + + // __syncthreads(); + // cute::copy(smem_tiled_copy_kv_pack, tSsK_pack_s2r, tSrK_pack_s2r_view); + // cute::copy(smem_tiled_copy_V_pack, tSsV_pack_s2r, tSrV_pack_s2r_view); + + // __syncthreads(); + // clear(tScales_k_h2_c); + // clear(tZeros_k_h2_c); + // clear(tScales_k_h2_g); + // clear(tZeros_k_h2_g); + + // clear(tScales_v_h2); + // clear(tZeros_v_h2); + + // CUTE_UNROLL + // for (int i = 0; i < size<1>(tScales_v_h2); ++i) { + // CUTE_UNROLL + // for (int j = 0; j < size<0>(tScales_v_h2); ++j) { + // tScales_v_h2(j, i) = gV_params(128 * (i / 8) + 0 + 8 * (i % 8) + 4 * (j / num_params_2) + tidx % 4, j % num_params_2); + // tZeros_v_h2(j, i) = gV_params(128 * (i / 8) + 64 + 8 * (i % 8) + 4 * (j / num_params_2) + tidx % 4, j % num_params_2); + // } + // } + + CUTE_UNROLL + for (int i = 0; i < size<2>(tSrV_pack); ++i) { + quant::dequant_Kchannel_Vtensor(tSrV_pack(_,_,i), tSrV_dequant(_,_,i), tScales_v_c(_,i), tZeros_v_c(_,i), num_params); + } + + if (Kernel_traits::quant_mode == 1) { + // CUTE_UNROLL + // for (int i = 0; i < size<1>(tScales_k_h2_c); ++i) { + // CUTE_UNROLL + // for (int j = 0; j < size<0>(tScales_k_h2_c); ++j) { + // tScales_k_h2_c(j, i) = gK_params(j % num_params, 0 + 8 * i + 4 * (j / num_params) + tidx % 4); + 
// tZeros_k_h2_c(j, i) = gK_params(j % num_params, 64 + 8 * i + 4 * (j / num_params) + tidx % 4); + // } + // } + + CUTE_UNROLL + for (int i = 0; i < size<2>(tSrK_pack); ++i) { + quant::dequant_Kchannel_Vtensor(tSrK_pack(_,_,i), tSrK_dequant(_,_,i), tScales_k_c(_,i), tZeros_k_c(_,i), num_params); + } + } else { + // CUTE_UNROLL + // for (int j = 0; j < size<0>(tScales_k_h2_g); ++j) { + // tScales_k_h2_g(j) = gK_params(0 + 32*j + tidx/4, 0); + // tZeros_k_h2_g(j) = gK_params(64 + 32*j + tidx/4, 0); + // } + + // auto tScales_k_h1_g = cute::recast<__half>(tScales_k_h2_g); + // auto tZeros_k_h1_g = cute::recast<__half>(tZeros_k_h2_g); + + // CUTE_UNROLL + // for (int i = 0; i < size<2>(tSrK_pack); ++i) { + // quant::dequantize_Ktensor(tSrK_pack, tSrK_dequant, tScales_k_h2_g, tZeros_k_h2_g, 4, group_size, i); + // } + } + + ////////////////////////////////////////////////////////////////////////////// + #if DEBUG2 + if (threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + PRINT("tSrK", tSrK.layout()); PRINTTENSOR("tSrK", tSrK); + PRINT("tSrK_dequant", tSrK_dequant.layout()); PRINTTENSOR("tSrK_dequant", tSrK_dequant); + PRINT("gK_pack", gK_pack.layout()); PRINTTENSOR("gK_pack", gK_pack); + // auto gK_params_f = cute::recast(gK_params); + // PRINT("gK_params", gK_params.layout()); PRINTTENSOR("gK_params", gK_params_f); + printf("#####################################################################################\n"); + } + #endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_qpack(const Params ¶ms) { + // The block index for the number of blocks. + const int blockN_idx = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + + flash::compute_qpack_1rowblock(params, bidb, bidh, blockN_idx); +} + + +} // namespace flash \ No newline at end of file diff --git a/hopper/src/include/heuristics.h b/hopper/src/include/heuristics.h new file mode 100644 index 0000000..8e7b4a3 --- /dev/null +++ b/hopper/src/include/heuristics.h @@ -0,0 +1,48 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +inline bool should_pack_gqa(bool varlen_q, int seqlen_q, int qhead_per_khead, int blockM) { + // If varlen, we don't actually know seqlen_q but only max_seqlen_q. + if (varlen_q) return true; + // Heuristic: PackGQA is a bit slower but can help if seqlen_q is small or not near a multiple of kBlockM + auto round_up = [](int a, int b) { return (a + b - 1) / b * b; }; + float nopack_gqa_efficiency = float(seqlen_q) / float(round_up(seqlen_q, blockM)); + float pack_gqa_efficiency = float(seqlen_q * qhead_per_khead) / float(round_up(seqlen_q * qhead_per_khead, blockM)); + return nopack_gqa_efficiency < 0.9 * pack_gqa_efficiency; +}; + +// Find the number of splits that maximizes the occupancy. For example, if we have +// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is +// better than having 3 splits (efficiency = 0.67). However, we also don't want too many +// splits as that would incur more HBM reads/writes. 
+// So we find the best efficiency, then find the smallest number of splits that gets 85% +// of the best efficiency. +inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) { + // If we have enough to almost fill the SMs, then just use 1 split + if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; } + // If num_n_blocks is too small, use 1 split. For example, we never split for hdim = 128 and seqlen_k = 512. + if (num_n_blocks <= 4) { return 1; } + max_splits = std::min({max_splits, num_SMs, num_n_blocks}); + float max_efficiency = 0.f; + std::vector efficiency; + efficiency.reserve(max_splits); + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs; + float eff = n_waves / ceil(n_waves); + // printf("num_splits = %d, eff = %f\n", num_splits, eff); + if (eff > max_efficiency) { max_efficiency = eff; } + efficiency.push_back(eff); + } + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) { + // printf("num_splits chosen = %d\n", num_splits); + return num_splits; + } + } + return 1; +} diff --git a/hopper/src/include/kernel_traits.h b/hopper/src/include/kernel_traits.h new file mode 100644 index 0000000..2b99e86 --- /dev/null +++ b/hopper/src/include/kernel_traits.h @@ -0,0 +1,176 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/layout/layout.h" +#include + +using namespace cute; + +template +struct Flash_kernel_traits { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using Element = elem_type; + static constexpr bool Has_cp_async = true; +#else + using Element = cutlass::half_t; + static constexpr bool Has_cp_async = false; +#endif + + using ElementAccum = float; + using index_t = int64_t; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using MMA_Atom_Arch = std::conditional_t< + std::is_same_v, + MMA_Atom, + MMA_Atom + >; +#else + using MMA_Atom_Arch = MMA_Atom; +#endif + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#else + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#endif +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template > +struct Flash_qpack_traits : public Base { + using Element = typename Base::Element; + using ElementKVPack = cute::uint16_t; + using index_t = typename Base::index_t; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr int quant_mode = quantmode_; + static constexpr int group_size = group_size_; + static constexpr int num_bits = num_bits_; + static constexpr int pack_num = 16 / num_bits; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockN = kBlockN_; + static constexpr int kBlockN_pack = num_bits == 4 ? 128 : 256; + static constexpr int kBlockP = quant_mode == 1 ? kBlockN / pack_num : kBlockN; + static constexpr int kBlockK_params = quant_mode == 1 ? 
kBlockN / group_size : kBlockN; + static constexpr int kHeadDim = kHeadDim_; + static constexpr int kHeadDim_pack = kHeadDim / pack_num; // TODO + static constexpr int kHeadDim_k = quant_mode == 1 ? kHeadDim : kHeadDim_pack; + static constexpr int kHeadDim_k_params = quant_mode == 1 ? kHeadDim : kHeadDim / group_size; + static constexpr int kHeadDim_v_params = kHeadDim / group_size; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + static constexpr int tile_paramsk_g = kBlockN / 32 * (kBlockN / group_size); // TODO: check + static constexpr int tile_paramsk_j = kBlockN / group_size; + static constexpr int tile_paramsk_k = kHeadDim / 16; + static constexpr int tile_paramsv_k = kBlockN / 16; // TODO: check 128 + + static constexpr int num_params = kBlockN_pack / group_size; + + using TiledMma = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout,_4,_1>>, + Tile, _128, _16>>; + + using TiledMmaK_i4 = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout,_4,_1>>, + Tile, Int<32>, _16>>; + + using SmemLayoutAtomKV_SW = decltype( + composition(Swizzle{}, + // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 + Layout>, + Stride, _1>>{})); + using SmemLayoutAtomK_tiled = decltype( + make_layout(make_shape(Int<8>{}, Int{}), + make_stride(Int{}, Int<1>{}))); + using SmemLayoutKV = decltype(tile_to_shape( + SmemLayoutAtomK_tiled{}, + Shape, Int>{})); + using SmemLayoutVtransposed = decltype( + composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); + + using SmemLayoutKPack = decltype( + make_layout(make_shape(Int{}, Int{}), + make_stride(Int{}, Int<1>{}))); + using SmemLayoutKPacktransposed_ = decltype( + composition(SmemLayoutKPack{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutKPacktransposed = std::conditional_t< + quant_mode == 1, + SmemLayoutKPack, + SmemLayoutKPacktransposed_ + >; + + using SmemLayoutAtomV = decltype( + make_layout(make_shape(Int<8>{}, Int{}), + make_stride(Int{}, Int<1>{}))); + using SmemLayoutVPack = decltype(tile_to_shape( + SmemLayoutAtomV{}, + Shape, Int>{})); + using SmemLayoutVPacktransposed = decltype( + composition(SmemLayoutVPack{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutVPacktransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVPacktransposed{})); + + // TODO: 32 of x can be determined. 
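+    // Plain 32x32 row-major scratch tile in smem; it backs smem_reduce_tmp, which the
+    // quantization routines receive as sReduce_tmp for their cross-thread reductions.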
+ using SmemLayoutReduce_tmp = decltype( + make_layout(make_shape(Int<32>{}, Int<32>{}), + make_stride(Int<32>{}, Int<1>{}))); + + using R2SCopyAtom = Copy_Atom; + using R2SCopyAtomPack = Copy_Atom; + + struct SharedStorage + { + array_aligned> smem_K; + array_aligned> smem_V; + array_aligned> smem_Kpack; + array_aligned> smem_Vpack; + array_aligned> smem_reduce_tmp; + }; + static constexpr int kSmemSize = int(sizeof(SharedStorage)); + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using GmemTiledCopyQKV = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + using GmemTileCopyK_Pack = decltype( + make_tiled_copy(Copy_Atom{}, + make_layout(make_shape(_32{}, _4{}), make_stride(_4{}, _1{})), + Layout>{})); // Val layout, 8 vals per store + using GmemTileCopyV_Pack = decltype( + make_tiled_copy(Copy_Atom{}, + make_layout(make_shape(_64{}, _2{}), make_stride(_2{}, _1{})), + Layout>{})); // Val layout, 8 vals per store +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/hopper/src/include/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/src/include/mainloop_fwd_sm90_tma_gmma_ws.hpp new file mode 100644 index 0000000..f121b86 --- /dev/null +++ b/hopper/src/include/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -0,0 +1,1165 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include "cutlass/pipeline/pipeline.hpp" + +#include "cute/tensor.hpp" + +#include "cutlass/gemm/collective/builders/sm90_common.inl" + +#include "flash.h" +#include "named_barrier.hpp" +#include "seqlen.h" +#include "mask.h" +#include "pack_gqa.h" +#include "paged_kv.h" +#include "rotary.h" +#include "utils.h" +#include "dequantize.h" +#include "sm90_pipeline_no_cluster.hpp" + +namespace flash { + +using namespace cute; + +template +struct CollectiveMainloopFwdSm90 { + + static constexpr int kStages = Stages; + using index_t = int64_t; + using ClusterShape = ClusterShape_; + using TileShape_MNK = TileShape_MNK_; + using Element = Element_; + using ElementKVPack = cute::uint16_t; + using ElementAccum = ElementAccum_; + using ArchTag = ArchTag_; + static constexpr bool Is_FP8 = cute::is_same_v || cute::is_same_v;; + static constexpr bool Is_causal = Is_causal_; + static constexpr bool Is_local = Is_local_; + static constexpr bool Has_softcap = Has_softcap_; + static constexpr bool Varlen = Varlen_; + static constexpr bool PagedKV = PagedKV_; + static constexpr bool AppendKV = AppendKV_; + static constexpr bool PackGQA = PackGQA_; + static constexpr bool Split = Split_; + static constexpr bool V_colmajor = V_colmajor_; + static constexpr bool Transpose_V = Is_FP8 && !V_colmajor; + static constexpr bool Use_TMA_Q = !PackGQA; + static constexpr bool Use_TMA_KV = !PagedKV; + + static constexpr int quant_mode = quant_mode_; + static constexpr int group_size = group_size_; + static constexpr int num_bits = num_bits_; + + static_assert(Use_TMA_KV || CUTE_STATIC_V(size(ClusterShape{})) == 1, "If not using TMA for KV, ClusterShape must be 1"); + static_assert(Use_TMA_KV || !V_colmajor, "If not using TMA for KV, V_colmajor is not supported"); + using SeqlenInfo_t = flash::SeqlenInfoQKNewK; + + static_assert(ArchTag::kMinComputeCapability >= 90); + + static constexpr cute::GMMA::Major MmaMajorV = !Is_FP8 && !V_colmajor ? GMMA::Major::MN : GMMA::Major::K; + static constexpr cute::GMMA::Major TmaMajorV = !V_colmajor ? GMMA::Major::MN : GMMA::Major::K; + + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + + static constexpr int pack_num = 16 / num_bits; + // static constexpr int kBlockN_Qpack = num_bits == 4 ? 128 : 256; + static constexpr int kBlockN_Qpack = 128; + static constexpr int kBlockN_pack = quant_mode == 1 ? kBlockN / pack_num : kBlockN; + static constexpr int kBlockN_params = quant_mode == 1 ? kBlockN / group_size : kBlockN; + static constexpr int kHeadDim_kpack = quant_mode == 1 ? kHeadDim : kHeadDim / pack_num; + static constexpr int kHeadDim_vpack = kHeadDim / pack_num; + static constexpr int kHeadDim_k_params = quant_mode == 1 ? 
kHeadDim : kHeadDim / group_size; + static constexpr int kHeadDim_v_params = kHeadDim / group_size; + static constexpr int tile_paramsk_j = kBlockN / group_size; + static constexpr int tile_paramsk_m = kBlockN / kBlockN_Qpack; + static constexpr int tile_paramsk_g = kBlockN / 32 * (kBlockN / group_size); // TODO: check + static constexpr int tile_paramsk_k = kHeadDim / 16; + static constexpr int tile_paramsv_k = kBlockN / 16; + + static constexpr int num_params = kBlockN_Qpack / group_size; // TODO: check 128 + + using TileShape_MNK_Kpack = decltype(make_shape(Int{}, Int{}, Int{})); + using TileShape_MNK_Vpack = decltype(make_shape(Int{}, Int{}, Int{})); + + // Register bandwidth is actually a bottleneck so we don't want Q to be in registers. + // Leaving this option here for reference. + static constexpr bool Mma0_is_RS = false; + // We can have Mma1 (P @ V) with P in smem in rmem to reduce register pressure at the cost of more smem. + static_assert(!(!Mma1_is_RS && !IntraWGOverlap), "Mma1 must be RS if IntraWGOverlap is enabled"); + static_assert(!(!Mma1_is_RS && Is_FP8), "Mma1 must be RS if FP8"); + static_assert(!(!Mma1_is_RS && Transpose_V), "Mma1 must be RS if Transpose_V"); + + using AtomLayoutMNK = Layout, _1, _1>>; + using TiledMma0 = decltype(cute::make_tiled_mma( + std::conditional_t< + !Mma0_is_RS, + decltype(cute::GMMA::ss_op_selector()), + decltype(cute::GMMA::rs_op_selector()) + >{}, + AtomLayoutMNK{})); + using TiledMma1 = decltype(cute::make_tiled_mma( + std::conditional_t< + !Mma1_is_RS, + decltype(cute::GMMA::ss_op_selector(TileShape_MNK{})), GMMA::Major::K, MmaMajorV>()), + decltype(cute::GMMA::rs_op_selector(TileShape_MNK{})), GMMA::Major::K, MmaMajorV>()) + >{}, + AtomLayoutMNK{})); + + /* Dequant */ + using TiledMma0_dequant = TiledMMA< + MMA_Atom, + Layout,_4,_1>>, + Tile, Int<32>, _16>>; + using TiledMma0_dequant_r2s = TiledMMA< + MMA_Atom, + Layout,_4,_1>>, + Tile, Int<128>, _16>>; + using S2RCopyAtomKPack = Copy_Atom; + using R2SCopyAtomKDequant = Copy_Atom; + using S2RCopyAtomVPack = Copy_Atom; + using R2SCopyAtomVDequant = Copy_Atom; + + static constexpr int NumMmaThreads = size(TiledMma0{}); + static constexpr int NumProducerThreads = !Transpose_V && Use_TMA_KV && Use_TMA_Q ? 
cutlass::NumThreadsPerWarp : cutlass::NumThreadsPerWarpGroup; + static_assert(NumMmaThreads % cutlass::NumThreadsPerWarpGroup == 0); + static constexpr int NumMmaWarpGroups = NumMmaThreads / cutlass::NumThreadsPerWarpGroup; + static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); + + using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{}))); + + using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutK = decltype(tile_to_shape( + SmemLayoutAtomK{}, + make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + + using SmemLayoutAtomKPack = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK_Kpack{})), decltype(cute::get<2>(TileShape_MNK_Kpack{}))>()); + using SmemLayoutKPack = decltype(tile_to_shape( + SmemLayoutAtomKPack{}, + make_shape(shape<1>(TileShape_MNK_Kpack{}), shape<2>(TileShape_MNK_Kpack{}), Int{}))); + + // using SmemLayoutAtomKPack = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK_Kpack{})), decltype(cute::get<1>(TileShape_MNK_Kpack{}))>()); + // using SmemLayoutKPack = decltype(tile_to_shape( + // SmemLayoutAtomKPack{}, + // make_shape(shape<1>(TileShape_MNK_Kpack{}), shape<2>(TileShape_MNK_Kpack{}), Int{}))); + + using SmemLayoutKParams_channel = decltype( + composition(Swizzle<2, 2, 3>{}, + Layout, Int>, + Stride, Int<1>>>{})); + using SmemLayoutAtomKParams_group = decltype( + make_layout(make_shape(Int<32>{}, Int<1>{}), + make_stride(Int<1>{}, Int<1>{}))); + using SmemLayoutKParams_group = decltype(tile_to_shape( + SmemLayoutAtomKParams_group{}, + Shape, Int>{})); + using SmemLayoutKParams = SmemLayoutKParams_group; + + using SmemLayoutAtomVt = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); + using SmemLayoutVt = decltype(tile_to_shape( + SmemLayoutAtomVt{}, + make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int{}), + std::conditional_t, cute::Step<_2, _1, _3>>{})); + using SmemLayoutVtMma = SmemLayoutVt; + + using SmemLayoutAtomVtPack = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK_Vpack{})), decltype(cute::get<1>(TileShape_MNK_Vpack{}))>()); + using SmemLayoutVtPack = decltype(tile_to_shape( + SmemLayoutAtomVtPack{}, + make_shape(shape<2>(TileShape_MNK_Vpack{}), shape<1>(TileShape_MNK_Vpack{}), Int{}), + std::conditional_t, cute::Step<_2, _1, _3>>{})); + + using SmemLayoutAtomVParams = decltype( + make_layout(make_shape(Int<32>{}, Int<1>{}), + make_stride(Int<1>{}, Int<1>{}))); + using SmemLayoutVParams = decltype(tile_to_shape( + SmemLayoutAtomVParams{}, + Shape, Int>{})); + + // Only used if we're using cp.async to load V + using SmemLayoutAtomVCpAsync = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutVCpAsync = decltype(tile_to_shape( + SmemLayoutAtomVCpAsync{}, + make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + + using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); + using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 
1>(TileShape_MNK{}))); + + using SmemCopyAtomP = Copy_Atom; + + // Use LDSM.T and STSM to transpose V in the case of FP8 and V being row-major. + // For FP16/BF16 we don't do any transposing. + static_assert(!Transpose_V || (kHeadDim % 32 == 0 && kBlockN % 32 == 0)); + static constexpr bool kHeadDim_multiple_64 = kHeadDim % 64 == 0; + // Either kHeadDim is a multiple of 64 (in which case we use a block size of 64 x 32 for the transpose), + // or we need kBlockN to be a multiple of 64 (in which case we use a block size of 32 x 64 for the transpose). + static_assert(!Transpose_V || (kHeadDim_multiple_64 || kBlockN % 64 == 0)); + using LDSM_thread_shape = std::conditional_t, Shape<_16, _4, _1, _2>>; + using LDSM_thread_stride = std::conditional_t, Stride<_4, _1, _0, _64>>; + using LDSM_value_shape = Shape<_2, _2, _1, _4>; + using LDSM_value_stride = Stride<_1, _2, _16, _4>; + using LDSM_divide_shape = std::conditional_t, Shape<_32, _8>>; + using S2RTiledCopyVt = decltype(make_tiled_copy( + Copy_Atom{}, Layout{}, + Layout{})); + + using STSM_thread_shape = std::conditional_t, Shape<_8, _4, _2, _2>>; + using STSM_thread_stride = std::conditional_t, Stride<_4, _1, _32, _64>>; + using STSM_value_shape = Shape<_1, _4, _2, _2>; + using STSM_value_stride = Stride<_0, _1, _4, _8>; + using STSM_divide_shape = Shape<_8, _16>; + // These will not permute the columns of V (the kHeadDim dimension) but incur bank conflicts + // so a little slower (e.g. 1150 TFLOPS for hdim 256 instead of 1200 TFLOPS). + // Instead we will permute the cols of V, and un-permute the cols of O in the epilogue. + // using STSM_value_shape = Shape<_2, _4, _1, _2>; + // using STSM_value_stride = Stride<_4, _1, _0, _8>; + // using STSM_divide_shape = Shape<_16, _16>; + using R2STiledCopyV = decltype(make_tiled_copy( + Copy_Atom{}, Layout{}, + Layout{})); + + using GmemTiledCopyQ = cute::SM90_TMA_LOAD; + using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{}))); + + // We use CpAsync for K and V if PagedKV and AppendKV, since TMA doesn't work there + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); + // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each + // thread to have 4 loads in the M direction and 2 vectorized load in the K direction. + // We want each thread to have at least 2 loads in the K direction since in the case of non-interleaved + // rotary (combining elements at indices 0 and rotary_dim/2, 1 and rotary_dim/2+1, etc), each thread will + // load twice from the same row. + static constexpr int kBytePerHalfRow = kHeadDim / 2 * sizeof(Element); + static constexpr int kBlockKGmem = (kBytePerHalfRow % 128 == 0 ? 128 : (kBytePerHalfRow % 64 == 0 ? 64 : 32)) / sizeof(Element); + static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; + static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow"); + // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKV where + // these threads share the same page table entry and share the work of computing pointers to paged K and paged V. 
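+    // Worked example with Element = fp16 and kHeadDim = 128: kGmemElemsPerLoad = 16 B / 2 B = 8,
+    // kBytePerHalfRow = 64 * 2 B = 128 B, so kBlockKGmem = 128 / 2 = 64 elements and
+    // kGmemThreadsPerRow = 64 / 8 = 8 threads cooperate on each row.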
+ static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, "kGmemThreadsPerRow must divide NumThreadsPerWarp"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + // If AppendKV, we'll be loading Q for rotary, and we assume divisibility to avoid predication + static_assert(!AppendKV || kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0, "kBlockM must be a multiple of NumMmaThreads / kGmemThreadsPerRow"); + using GmemTiledCopyAppendKV = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 8 or 16 vals per store + + using ShapeQKV = cute::Shape; // (seqlen, d, head, batch) + using StrideQK = cute::Stride; + // using StrideKParams = cute::Stride; + using StrideKParams = cute::Stride<_1, int64_t, int64_t, int64_t>; + using StrideVParams = cute::Stride<_1, int64_t, int64_t, int64_t>; + using StrideV = std::conditional_t>; + // ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits) + using ShapeQPacked = std::conditional_t, int32_t, int32_t, int32_t>>; + using StrideQPacked = std::conditional_t, _1, int64_t, int64_t>>; + using ShapePageTable = cute::Shape; // (batch, max_num_pages_per_seq) + using StridePageTable = cute::Stride; + using ShapeRotary = cute::Shape; // (seqlen_ro, rotary_dim // 2) + using StrideRotary = cute::Stride; + using StrideDescale = cute::Stride; + + using TMA_Q = decltype(make_tma_copy_A_sm90( + GmemTiledCopyQ{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQK{}), + SmemLayoutQ{}, + TileShape_MNK{}, + ClusterShape{})); + + using TMA_K_pack = decltype(make_tma_copy_B_sm90( + GmemTiledCopyKV{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQK{}), + take<0, 2>(SmemLayoutKPack{}), + TileShape_MNK_Kpack{}, + ClusterShape{})); + + // using GmemTileCopyKParams = decltype( + // make_tiled_copy(Copy_Atom, __half2>{}, + // make_layout(make_shape(_1{}, _128{}), make_stride(_1{}, _1{})), + // Layout>{})); // Val layout, 4 vals per store + + using GmemTileCopyKParams = decltype( + make_tiled_copy(Copy_Atom, __half2>{}, + make_layout(make_shape(_128{}, _1{}), make_stride(_1{}, _1{})), + Layout>{})); // Val layout, 4 vals per store + + using GmemTileCopyVParams = decltype( + make_tiled_copy(Copy_Atom, __half2>{}, + make_layout(make_shape(_128{}, _1{}), make_stride(_1{}, _1{})), + Layout>{})); // Val layout, 4 vals per store + + using TMA_V_pack = decltype(make_tma_copy( + GmemTiledCopyKV{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, select<1, 0, 2, 3>(StrideV{})), + take<0, 2>(SmemLayoutVtPack{}), + select<2, 1>(TileShape_MNK_Vpack{}), + size<0>(ClusterShape{}))); + + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + static constexpr uint32_t TmaTransactionBytesQ = static_cast(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v / 8); + static constexpr uint32_t TmaTransactionBytesK_original = static_cast(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v / 8); + static constexpr uint32_t TmaTransactionBytesK_pack = static_cast(size(take<0, 2>(SmemLayoutKPack{})) * cutlass::sizeof_bits_v / 8); + static constexpr uint32_t TmaTransactionBytesK = TmaTransactionBytesK_pack; + static constexpr uint32_t TmaTransactionBytesV_original = static_cast(size(take<0, 2>(SmemLayoutVt{})) * cutlass::sizeof_bits_v / 8); + static constexpr uint32_t TmaTransactionBytesV_pack = static_cast(size(take<0, 2>(SmemLayoutVtPack{})) * cutlass::sizeof_bits_v / 8); + static constexpr uint32_t TmaTransactionBytesV = 
TmaTransactionBytesV_pack; + // static_assert(TmaTransactionBytesK == TmaTransactionBytesV); + + using PipelineTmaAsync = std::conditional_t, typename cutlass::PipelineTmaAsync>; + using MainloopPipelineK = std::conditional_t>; + using MainloopPipelineV = std::conditional_t>; + using MainloopPipelineVt = std::conditional_t>; + // We always use TMA for K_new and V_new + using MainloopPipelineKVNew = PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + // If PackGQA, we use cp.async (instead of TMA) to load Q, so we want smem_q to be aligned + // and have sQ being position_independent_swizzle_tensor. + // If !Use_TMA_KV, we use cp.async (instead of TMA) to load K & V, so we want smem_k and smem_v to be aligned. + static constexpr size_t SmemAlignmentQ = Use_TMA_Q && !AppendKV && !Mma0_is_RS ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutQ{}); + static constexpr size_t SmemAlignmentK = Use_TMA_KV && !AppendKV ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutK{}); + static constexpr size_t SmemAlignmentVtNoTranspose = cutlass::detail::alignment_for_swizzle(SmemLayoutVt{}); + static constexpr size_t SmemAlignmentVtPack = cutlass::detail::alignment_for_swizzle(SmemLayoutVtPack{}); + static_assert(SmemAlignmentQ >= 128 and SmemAlignmentK >= 128 && SmemAlignmentVtNoTranspose >= 128 && SmemAlignmentVtPack >= 128, "Require at least 128B alignment"); + static constexpr size_t SmemAlignmentP = cutlass::detail::alignment_for_swizzle(SmemLayoutP{}); + static_assert(SmemAlignmentP >= 128, "Require at least 128B alignment"); + + struct TensorStorageWithoutPNoTranspose : cute::aligned_struct { + cute::array_aligned, SmemAlignmentQ> smem_q; + cute::array_aligned, SmemAlignmentK> smem_k; + // cute::array_aligned, SmemAlignmentK> smem_k_pack; + cute::array_aligned<__half2, cute::cosize_v> smem_k_params; + cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; + // cute::array_aligned, SmemAlignmentVtPack> smem_v_pack; + cute::array_aligned<__half2, cute::cosize_v> smem_v_params; + }; + + using TensorStorage = TensorStorageWithoutPNoTranspose; + + // These are tuned for speed. They don't affect correctness. + static constexpr bool UseSchedulerBarrier = IntraWGOverlap + ? (NumMmaWarpGroups >= 2) && (!Is_FP8 ? 
kHeadDim <= 128 : kHeadDim >= 128) + : NumMmaWarpGroups == 2; + static constexpr bool RescaleOBeforeGemm = kHeadDim > 128 && (!Is_FP8 || V_colmajor); + + // Host side kernel arguments + struct Arguments { + Element const* const ptr_Q; + ShapeQKV const shape_Q; + StrideQK const stride_Q; + + // Element* const ptr_K; // Not Element const* since we might append to KV cache in-place + ShapeQKV const shape_K; + // StrideQK const stride_K; + ElementKVPack* const ptr_K_pack; + ShapeQKV const shape_K_pack; + StrideQK const stride_K_pack; + __half2* const ptr_K_params; + ShapeQKV const shape_K_params; + StrideKParams const stride_K_params; + + // Element* const ptr_V; + // StrideV const stride_V; + ElementKVPack* const ptr_V_pack; + ShapeQKV const shape_V_pack; + StrideQK const stride_V_pack; + __half2* const ptr_V_params; + ShapeQKV const shape_V_params; + StrideVParams const stride_V_params; + + Element const* const ptr_K_new; + ShapeQKV const shape_K_new; + StrideQK const stride_K_new; + Element const* const ptr_V_new; + StrideV const stride_V_new; + Element const* const ptr_rotary_cos; + ShapeRotary const shape_rotary; + StrideRotary const stride_rotary_cos; + Element const* const ptr_rotary_sin; + StrideRotary const stride_rotary_sin; + bool const is_rotary_interleaved; + int const* const ptr_pagetable; + ShapePageTable const shape_pagetable; + StridePageTable const stride_pagetable; + float const softmax_scale; + float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; + StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; + int const window_size_left = -1, window_size_right = -1, sink_token_length = 0; + float const softcap_val; + int const num_splits; + int const* const kv_batch_idx = nullptr; + int const* const cu_seqlens_q = nullptr; + int const* const cu_seqlens_k = nullptr; + int const* const cu_seqlens_k_new = nullptr; + int const* const seqused_q = nullptr; + int const* const seqused_k = nullptr; + int const* const leftpad_k = nullptr; + }; + + // Device side kernel params + struct Params { + Element const* const ptr_Q; + ShapeQKV const shape_Q; + StrideQK const stride_Q; + ShapeQPacked const shape_Q_packed; + StrideQPacked const stride_Q_packed; + + // Element* const ptr_K; + ShapeQKV const shape_K; + // StrideQK const stride_K; + ElementKVPack* const ptr_K_pack; + ShapeQKV const shape_K_pack; + StrideQK const stride_K_pack; + __half2* const ptr_K_params; + ShapeQKV const shape_K_params; + StrideKParams const stride_K_params; + + // Element* const ptr_V; + // StrideV const stride_V; + ElementKVPack* const ptr_V_pack; + ShapeQKV const shape_V_pack; + StrideQK const stride_V_pack; + __half2* const ptr_V_params; + ShapeQKV const shape_V_params; + StrideVParams const stride_V_params; + + Element const* const ptr_K_new; + ShapeQKV const shape_K_new; + StrideQK const stride_K_new; + Element const* const ptr_V_new; + StrideV const stride_V_new; + Element const* const ptr_rotary_cos; + ShapeRotary const shape_rotary; + StrideRotary const stride_rotary_cos; + Element const* const ptr_rotary_sin; + StrideRotary const stride_rotary_sin; + bool const is_rotary_interleaved; + int const* const ptr_pagetable; + ShapePageTable const shape_pagetable; + StridePageTable const stride_pagetable; + cutlass::FastDivmod page_size_divmod; + cutlass::FastDivmod qhead_per_khead_divmod; + + TMA_Q tma_load_Q; + + // TMA_K tma_load_K; + TMA_K_pack tma_load_K_pack; + + // TMA_V tma_load_V; + TMA_V_pack tma_load_V_pack; + + // TMA_K tma_load_K_new; + // TMA_V tma_load_V_new; + float const 
softmax_scale_log2; + float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; + StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; + float const softcap_val; + int const window_size_left, window_size_right, sink_token_length; + int const num_splits; + int const* const kv_batch_idx = nullptr; + int const* const cu_seqlens_q = nullptr; + int const* const cu_seqlens_k = nullptr; + int const* const cu_seqlens_k_new = nullptr; + int const* const seqused_q = nullptr; + int const* const seqused_k = nullptr; + int const* const leftpad_k = nullptr; + }; + + static Params + to_underlying_arguments(Arguments const& args) { + Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.shape_Q, args.stride_Q); + TMA_Q tma_load_Q = make_tma_copy_A_sm90( + GmemTiledCopyQ{}, + mQ, + SmemLayoutQ{}, + TileShape_MNK{}, + ClusterShape{}); // no mcast for Q + + // K + // Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.shape_K, args.stride_K); + // TMA_K tma_load_K = make_tma_copy_B_sm90( + // GmemTiledCopyKV{}, + // mK, + // take<0, 2>(SmemLayoutK{}), + // TileShape_MNK{}, + // ClusterShape{}); // mcast along M mode for this N load, if any + Tensor mK_pack = make_tensor(make_gmem_ptr(args.ptr_K_pack), args.shape_K_pack, args.stride_K_pack); + TMA_K_pack tma_load_K_pack = make_tma_copy_B_sm90( + GmemTiledCopyKV{}, + mK_pack, + take<0, 2>(SmemLayoutKPack{}), + TileShape_MNK_Kpack{}, + ClusterShape{}); + + // V Tensor + // Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), select<1, 0, 2, 3>(args.shape_K), select<1, 0, 2, 3>(args.stride_V)); + // TMA_V tma_load_V = make_tma_copy( + // GmemTiledCopyKV{}, + // mV, + // take<0, 2>(SmemLayoutVt{}), + // select<2, 1>(TileShape_MNK{}), + // size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + + // V_pack Tensor + Tensor mV_pack = make_tensor(make_gmem_ptr(args.ptr_V_pack), select<1, 0, 2, 3>(args.shape_V_pack), select<1, 0, 2, 3>(args.stride_V_pack)); + TMA_V_pack tma_load_V_pack = make_tma_copy( + GmemTiledCopyKV{}, + mV_pack, + take<0, 2>(SmemLayoutVtPack{}), + select<2, 1>(TileShape_MNK_Vpack{}), + size<0>(ClusterShape{})); + + // If PackGQA, reshape Q to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size) + int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K)); + auto const shape_Q_packed = cute::conditional_return( + args.shape_Q, + make_shape(make_shape(qhead_per_khead, get<0>(args.shape_Q)), get<1>(args.shape_Q), get<2>(args.shape_K), get<3>(args.shape_Q)) + ); + auto const stride_Q_packed = cute::conditional_return( + args.stride_Q, + make_stride(make_stride(get<2>(args.stride_Q), get<0>(args.stride_Q)), get<1>(args.stride_Q), get<2>(args.stride_Q) * qhead_per_khead, get<3>(args.stride_Q)) + ); + if (get<1>(args.shape_rotary) > 0) { + assert(args.ptr_rotary_cos != nullptr && args.ptr_rotary_sin != nullptr); + } + assert(args.num_splits >= 1); + // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. + // Right after this, we multiply by log2(e) before applying exp2. + // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val + // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e) + // (assigning it to params.softmax_scale_log2). 
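+    // Worked example of the pre-multiplication above (illustrative numbers only):
+    //   softmax_scale = 0.1, softcap_val = 30.0, Has_softcap = true
+    //     params.softcap_val        = 0.1 / 30.0        (folded into the tanh argument)
+    //     params.softmax_scale_log2 = 30.0 * log2(e)    (applied after tanh, before exp2)
+    //   then exp2( tanh(S * params.softcap_val) * params.softmax_scale_log2 )
+    //     == exp ( tanh(S * softmax_scale / softcap_val) * softcap_val ),
+    //   i.e. the softcapped, scaled score, with no extra per-element multiply in the inner loop.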
+ return {args.ptr_Q, args.shape_Q, args.stride_Q, shape_Q_packed, stride_Q_packed, + // args.ptr_K, + args.shape_K, + // args.stride_K, + args.ptr_K_pack, args.shape_K_pack, args.stride_K_pack, + args.ptr_K_params, args.shape_K_params, args.stride_K_params, + // args.ptr_V, args.stride_V, + args.ptr_V_pack, args.shape_V_pack, args.stride_V_pack, + args.ptr_V_params, args.shape_V_params, args.stride_V_params, + args.ptr_K_new, args.shape_K_new, args.stride_K_new, args.ptr_V_new, args.stride_V_new, + args.ptr_rotary_cos, args.shape_rotary, args.stride_rotary_cos, + args.ptr_rotary_sin, args.stride_rotary_sin, args.is_rotary_interleaved, + args.ptr_pagetable, args.shape_pagetable, args.stride_pagetable, + cutlass::FastDivmod(int(get<0>(args.shape_K))), + cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), + tma_load_Q, + // tma_load_K, + tma_load_K_pack, + // tma_load_V, + tma_load_V_pack, + // tma_load_K_new, + // tma_load_V_new, + !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E), + args.ptr_q_descale, args.ptr_k_descale, args.ptr_v_descale, + args.stride_q_descale, args.stride_k_descale, args.stride_v_descale, + !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val, + args.window_size_left, args.window_size_right, args.sink_token_length, + !Split ? 1 : args.num_splits, + args.kv_batch_idx, + args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new, + args.seqused_q, args.seqused_k, args.leftpad_k}; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + if constexpr (Use_TMA_Q) { + cute::prefetch_tma_descriptor(params.tma_load_Q.get_tma_descriptor()); + } + if constexpr (Use_TMA_KV) { + cute::prefetch_tma_descriptor(params.tma_load_K_pack.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_V_pack.get_tma_descriptor()); + } + + } + + CUTLASS_DEVICE + cute::tuple get_n_block_min_max(Params const& params, SeqlenInfo_t const& seqlen_info, + int m_block, int bidb, int split_idx=0, int num_splits=1) { + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + int const seqlen_k = seqlen_info.seqlen_k; + int const seqlen_q = seqlen_info.seqlen_q; + int n_block_max = cute::ceil_div(seqlen_k, kBlockN); + if constexpr (Is_causal || Is_local) { + int m_idx_max = (m_block + 1) * kBlockM; + // TODO: check off-by-1 error + if (PackGQA) { m_idx_max = params.qhead_per_khead_divmod.divide(m_idx_max - 1) + 1 ; } + n_block_max = std::min(n_block_max, + cute::ceil_div(m_idx_max + seqlen_k - seqlen_q + params.window_size_right, kBlockN)); + } + int n_block_min = 0; + if constexpr (Is_local) { + int m_idx_min = m_block * kBlockM; + if (PackGQA) { m_idx_min = params.qhead_per_khead_divmod.divide(m_idx_min); } + n_block_min = std::max(int(0), (m_idx_min + seqlen_k - seqlen_q - params.window_size_left) / kBlockN); + } + // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); } + if constexpr (Split) { + int num_n_blocks_per_split = n_block_max <= n_block_min ? 
0 : cute::ceil_div(n_block_max - n_block_min, num_splits); + n_block_min = n_block_min + split_idx * num_n_blocks_per_split; + n_block_max = std::min(n_block_min + num_n_blocks_per_split, n_block_max); + } + // if (threadIdx.x == 128) { printf("After split, inside, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); } + return {n_block_min, n_block_max}; + } + + template + CUTLASS_DEVICE void + load(Params const& params, + MainloopPipelineK pipeline_k, + MainloopPipelineV pipeline_v, + MainloopPipelineVt pipeline_vt, + PipelineState& smem_pipe_write, + SharedStorage &shared_storage, + SchedulerPrefetch const& scheduler_prefetch, + SeqlenInfo_t const& seqlen_info, + cute::tuple block_coord, + int &work_idx + ) { + + + auto [m_block, bidh, bidb, split_idx] = block_coord; + auto [n_block_min, n_block_max] = get_n_block_min_max(params, seqlen_info, m_block, bidb, split_idx, params.num_splits); + // It's possible to have n_block_max <= n_block_min. Loading K can cause illegal memory access. + + if constexpr (Is_causal || Is_local || Varlen || Split) { + if (n_block_max <= n_block_min) { + scheduler_prefetch(); + return; + } + } + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{}); + + Tensor sK_pack = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutKPack{}); + Tensor sK_params = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k_params.data()), SmemLayoutKParams{}); + + Tensor sVt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutVt{}); + Tensor sVt_pack = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtPack{}); + Tensor sV_params = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v_params.data()), SmemLayoutVParams{}); + + // Only used if Transpose_V + Tensor sV = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutVtMma{})); + + int const thread_idx = threadIdx.x % NumProducerThreads; + int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; + int const bidb_kv = params.kv_batch_idx == nullptr ? bidb : params.kv_batch_idx[bidb]; + + auto [k_params_row_stride, k_params_dim_stride, k_params_head_stride, k_params_batch_stride] = params.stride_K_params; + auto [v_params_row_stride, v_params_dim_stride, v_params_head_stride, v_params_batch_stride] = params.stride_V_params; + + const index_t row_offset_k_params = bidb_kv * k_params_batch_stride + bidh_kv * k_params_head_stride + + (n_block_max - 1) * kBlockN_params * k_params_row_stride; + const index_t row_offset_v_params = bidb_kv * v_params_batch_stride + bidh_kv * v_params_head_stride + + (n_block_max - 1) * kBlockN * v_params_row_stride; + + // Prepare the TMA loads + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + bool const is_varlen_q = Varlen && params.cu_seqlens_q; + bool const is_varlen_k = Varlen && params.cu_seqlens_k; + + Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); + + Tensor mK_TMA_pack = params.tma_load_K_pack.get_tma_tensor(params.shape_K_pack)(_, _, bidh_kv, !is_varlen_k ? 
bidb_kv : 0); + Tensor mVt_TMA_pack = params.tma_load_V_pack.get_tma_tensor(select<1, 0, 2, 3>(params.shape_V_pack))(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); + + Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + Tensor gK_TMA_pack = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mK_TMA_pack), select<1, 2>(TileShape_MNK_Kpack{}), make_coord(_, _0{})); // (N, K, _) + Tensor gK_params = make_tensor(make_gmem_ptr(reinterpret_cast<__half2*>(params.ptr_K_params) + row_offset_k_params), + Shape, Int>{}, + make_stride(k_params_row_stride, k_params_dim_stride)); + + Tensor gVt_TMA_pack = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k), mVt_TMA_pack), select<2, 1>(TileShape_MNK_Vpack{}), make_coord(_0{}, _)); // (K, N, _) + Tensor gV_params = make_tensor(make_gmem_ptr(reinterpret_cast<__half2*>(params.ptr_V_params) + row_offset_v_params), + Shape, Int>{}, + make_stride(_1{}, v_params_dim_stride)); + + auto block_tma_Q = params.tma_load_Q.get_slice(_0{}); + Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ)); // (TMA) + Tensor tQsQ = group_modes<0, 3>(block_tma_Q.partition_D(sQ)); // (TMA) + + auto block_tma_K_pack = params.tma_load_K_pack.get_slice(cluster_local_block_id.x); + Tensor tKgK_TMA_pack = group_modes<0, 3>(block_tma_K_pack.partition_S(gK_TMA_pack)); // (TMA, k) + Tensor tKsK_TMA_pack = group_modes<0, 3>(block_tma_K_pack.partition_D(sK_pack)); // (TMA, PIPE) + + auto block_tma_V_pack = params.tma_load_V_pack.get_slice(cluster_local_block_id.x); + Tensor tVgVt_TMA_pack = group_modes<0, 3>(block_tma_V_pack.partition_S(gVt_TMA_pack)); // (TMA, k) + Tensor tVsVt_TMA_pack = group_modes<0, 3>(block_tma_V_pack.partition_D(sVt_pack)); // (TMA, PIPE) + + using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumProducerThreads, Element, Transpose_V || !IntraWGOverlap /*KV_Same_Iter*/>; + + uint16_t mcast_mask_kv = 0; + + auto load_K = [&] (int const n_block, auto const& smem_pipe_write, auto need_seqlenk_masking_type) { + pipeline_k.producer_acquire(smem_pipe_write); + copy(params.tma_load_K_pack.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_LAST), + tKgK_TMA_pack(_, n_block), tKsK_TMA_pack(_, smem_pipe_write.index())); + }; + + auto load_V = [&] (int const n_block, auto const& smem_pipe_write, auto need_seqlenk_masking_type) { + auto pipeline_v_load = cute::conditional_return(pipeline_v, pipeline_vt); + pipeline_v_load.producer_acquire(smem_pipe_write); + copy(params.tma_load_V_pack.with(*pipeline_v_load.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_LAST), + tVgVt_TMA_pack(_, n_block), tVsVt_TMA_pack(_, smem_pipe_write.index())); + + }; + + int n_block = n_block_max - 1; + + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); + // If this is true, we're guaranteed that only the first warp will execute this function + static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp; + bool should_load_KV = !Use_TMA_KV || ((SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync()); + + if (should_load_KV) { + if constexpr (Transpose_V) { load_V(n_block, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); } + load_K(n_block, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); + } + + // load Q + cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, 
static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + Tensor mQ_s = make_tensor(make_gmem_ptr(params.ptr_Q + seqlen_info.offset_q * get<0>(params.stride_Q)), params.shape_Q_packed, params.stride_Q_packed)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor sQ_pi = cute::as_position_independent_swizzle_tensor(sQ); + using PackGQAt = flash::PackGQAManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumProducerThreads, Element>; + PackGQAt::load_Q(mQ_s, sQ_pi, params.qhead_per_khead_divmod, thread_idx, seqlen_info.seqlen_q, m_block); + auto &barrier_Q = shared_storage.pipelines.barrier_Q; + + cutlass::arch::cpasync_barrier_arrive(reinterpret_cast(&barrier_Q)); + barrier_Q.arrive(); + + // Wait for the MMA WGs to signal that smem_v are ready and V can be copied from gmem + // Need ClusterBarrier, not just NamedBarrier. Otherwise we might have CTA 0 finishing the + // TMA store on O first, call TMA multicast load on V, before CTA 1 can finishing TMA store on O. + // if (thread_idx == 0) { printf("Producer: main load, before barrier_O, work_idx = %d\n", work_idx);} + shared_storage.pipelines.barrier_O.wait((work_idx + 1) % 2); + // if (thread_idx == 0) { printf("Producer: main load, after barrier_O\n");} + + if constexpr (!Transpose_V && !IntraWGOverlap) { + if (should_load_KV) { load_V(n_block, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); } + } + int n_block_prev = n_block; + --n_block; + #pragma unroll (!Transpose_V && Use_TMA_KV ? 2 : 1) + for (; n_block >= n_block_min; --n_block) { + PipelineState smem_pipe_write_v = smem_pipe_write; // copy the state, write_v is always 1 step behind + ++smem_pipe_write; + if (should_load_KV) { + if constexpr (Transpose_V) { load_V(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); } + load_K(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); + if constexpr (!Transpose_V) { + if constexpr (IntraWGOverlap) { + load_V(n_block_prev, smem_pipe_write_v, cute::true_type{} /*Seqlenk_mask*/); + } else { + load_V(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); + } + } + } + n_block_prev = n_block; + } + + scheduler_prefetch(); + if constexpr (!Transpose_V && IntraWGOverlap) { + if (should_load_KV) { load_V(n_block_prev, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); } + } + ++smem_pipe_write; + // At the end, all threads have the correct smem_pipe_write. 
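+        // Load order produced by the code above when IntraWGOverlap is enabled and V is not
+        // transposed (a sketch; actual execution interleaves with pipeline acquires/releases):
+        //   before the loop : load_K(n_max-1)
+        //   loop iteration i: load_K(n_max-2-i), load_V(n_max-1-i)   // V trails K by one n_block
+        //   after the loop  : load_V(n_min), issued after scheduler_prefetch()
+        // so the consumer's S = Q K^T gemm on a block overlaps with the producer's V load for that block.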
+ ++work_idx; + } + + template + CUTLASS_DEVICE void + load_tail(MainloopPipelineK pipeline_k, MainloopPipelineV pipeline_v, MainloopPipelineVt pipeline_vt, + PipelineState& smem_pipe_write, SharedStorage &shared_storage, int const work_idx) { + + shared_storage.pipelines.barrier_O.wait((work_idx + 1) % 2); + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); + if (warp_idx_in_warpgroup == 0 && cute::elect_one_sync()) { + pipeline_k.producer_tail(smem_pipe_write); + pipeline_v.producer_tail(smem_pipe_write); + if constexpr (Transpose_V) { pipeline_vt.producer_tail(smem_pipe_write); } + } + } + + CUTLASS_DEVICE void + warp_scheduler_barrier_sync() { + if constexpr (UseSchedulerBarrier) { + cutlass::arch::NamedBarrier::sync(2 * cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + flash::canonical_warp_group_idx_nosync() /*id*/); + } + } + + CUTLASS_DEVICE void + warp_scheduler_barrier_arrive() { + if constexpr (UseSchedulerBarrier) { + static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); + int const cur_WG = flash::canonical_warp_group_idx_nosync() - 1; + int const next_WG = NumMmaWarpGroups == 2 + ? 1 - cur_WG + : (cur_WG < NumMmaWarpGroups - 1 ? cur_WG + 1 : 0); + cutlass::arch::NamedBarrier::arrive(2 * cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) + next_WG /*id*/); + } + } + + CUTLASS_DEVICE void + mma_init() { + // Tell producers that smem_q is ready + cutlass::arch::NamedBarrier::arrive(NumMmaThreads + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + // cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::KParamsEmpty) /*id*/); + if constexpr (UseSchedulerBarrier) { + // We have NamedBarrier for up to 3 WGs + static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); + // WG1 needs the very first signal to start + if (flash::canonical_warp_group_idx_nosync() == 1) { + cutlass::arch::NamedBarrier::arrive(2 * cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) /*id*/); + } + } + } + + template + CUTLASS_DEVICE bool + mma(Params const& params, + MainloopPipelineK pipeline_k, + MainloopPipelineV pipeline_v, + PipelineState& smem_pipe_read, + FrgTensorO& tOrO, + Softmax& softmax, + int const thread_idx, + int &work_idx, + SeqlenInfo_t const& seqlen_info, + cute::tuple block_coord, + SharedStorage& shared_storage + ) { + static_assert(is_rmem::value, "O tensor must be rmem resident."); + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + + // can't use auto [m_block, ...] = block_coord since structured binding cannot be captured in lambda + int const m_block = get<0>(block_coord); + int const bidh = get<1>(block_coord); + int const bidb = get<2>(block_coord); + int const split_idx = get<3>(block_coord); + int const bidb_kv = params.kv_batch_idx == nullptr ? bidb : params.kv_batch_idx[bidb]; + int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; + auto [n_block_min, n_block_max] = get_n_block_min_max(params, seqlen_info, m_block, bidb, split_idx, params.num_splits); + + #if DEBUG + if (threadIdx.x == 128 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + printf("n_block_min: %d, n_block_max: %d\n", n_block_min, n_block_max); + } + #endif + + // It's possible to have n_block_max <= n_block_min. 
We don't want to load Q or change any barrier + if constexpr (Is_causal || Is_local || Varlen || Split) { + if (n_block_max <= n_block_min) { return false; } + } + + auto [k_params_row_stride, k_params_dim_stride, k_params_head_stride, k_params_batch_stride] = params.stride_K_params; + auto [v_params_row_stride, v_params_dim_stride, v_params_head_stride, v_params_batch_stride] = params.stride_V_params; + + const index_t row_offset_k_params = bidb_kv * k_params_batch_stride + bidh_kv * k_params_head_stride + + (n_block_max - 1) * kBlockN_params * k_params_row_stride; + const index_t row_offset_v_params = bidb_kv * v_params_batch_stride + bidh_kv * v_params_head_stride + + (n_block_max - 1) * kBlockN * v_params_row_stride; + + Tensor gK_params = make_tensor(make_gmem_ptr(reinterpret_cast<__half2*>(params.ptr_K_params) + row_offset_k_params), + Shape, Int>{}, + make_stride(k_params_row_stride, k_params_dim_stride)); + Tensor gV_params = make_tensor(make_gmem_ptr(reinterpret_cast<__half2*>(params.ptr_V_params) + row_offset_v_params), + Shape, Int>{}, + make_stride(_1{}, v_params_dim_stride)); + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{}); + + Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); + Tensor sK_pack = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutKPack{}); + Tensor sK_params = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k_params.data()), SmemLayoutKParams{}); + + Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtMma{}); + Tensor sV_pack = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtPack{}); + Tensor sV_params = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v_params.data()), SmemLayoutVParams{}); + + Tensor sP = [&] { + if constexpr (Mma1_is_RS) { + // We might not have smem_p if !Mma1_is_RS1, just use smem_q as a placeholder since we don't use it + return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutP{}); + } else { + return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutP{}); + } + }(); + + if constexpr (!Mma0_is_RS) { + static_assert(stride<0>(typename TiledMma0::ALayout{}) == 0 and + stride<0>(typename TiledMma0::BLayout{}) == 0 and + size<0>(typename TiledMma0::ALayout{}) == cutlass::NumThreadsPerWarpGroup and + size<0>(typename TiledMma0::BLayout{}) == cutlass::NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + } + constexpr int MmaWarpGroups = size(TiledMma0{}) / cutlass::NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(make_shape(Int{}), + make_stride(Int{})); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); + TiledMma0 tiled_mma0; + TiledMma1 tiled_mma1; + auto wg_mma0 = tiled_mma0.get_slice(warp_group_thread_layout(warp_group_idx)); + auto wg_mma1 = tiled_mma1.get_slice(warp_group_thread_layout(warp_group_idx)); + + auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtomP{}, tiled_mma0); + auto smem_thr_copy_P = smem_tiled_copy_P.get_thread_slice(thread_idx); + + // Allocate "fragments/descriptors" + Tensor tSrQ = wg_mma0.partition_fragment_A(sQ); + Tensor tSrK = wg_mma0.partition_fragment_B(sK); + + Tensor tOrV = wg_mma1.partition_fragment_B(sV); + Tensor tOsP = 
wg_mma1.partition_fragment_A(sP); + Tensor tPsP = smem_thr_copy_P.partition_D(cute::as_position_independent_swizzle_tensor(sP)); + + /* Dequant K */ + TiledMma0_dequant tiled_mma0_dequant; + TiledMma0_dequant_r2s tiled_mma0_dequant_r2s; + auto mma0_dequant = tiled_mma0_dequant.get_slice(thread_idx); + // Tensor tSrK_f = mma0_dequant.partition_fragment_B(sK); + Tensor tSrK_dequant = mma0_dequant.partition_fragment_B(sK); + Tensor tSrK_pack_tmp = mma0_dequant.partition_fragment_B(sK_pack); + Tensor tSrK_pack = make_fragment_like(tSrK_pack_tmp); + + // S2R + auto smem_tiled_copy_K_pack = make_tiled_copy_B(S2RCopyAtomKPack{}, tiled_mma0_dequant); + auto smem_thr_copy_K_pack = smem_tiled_copy_K_pack.get_slice(thread_idx); + Tensor tSsK = smem_thr_copy_K_pack.partition_S(sK); + // Tensor tSrK_view = smem_thr_copy_K_pack.retile_D(tSrK_f); + Tensor tSsK_pack = smem_thr_copy_K_pack.partition_S(sK_pack); + Tensor tSrK_pack_view = smem_thr_copy_K_pack.retile_D(tSrK_pack); + + // R2S + auto smem_tiled_copy_K_dequant = make_tiled_copy_B(R2SCopyAtomKDequant{}, tiled_mma0_dequant_r2s); + auto smem_thr_copy_K_dequant = smem_tiled_copy_K_dequant.get_slice(thread_idx); + // Tensor tSrK_r2s = smem_thr_copy_K_dequant.retile_S(tSrK_f); + Tensor tSsK_r2s = smem_thr_copy_K_dequant.partition_D(sK); + Tensor tSrK_dequant_r2s = smem_thr_copy_K_dequant.retile_S(tSrK_dequant); + + /* Dequant V */ + Tensor tSrV_dequant = mma0_dequant.partition_fragment_B(sV); + Tensor tSrV_pack_tmp = mma0_dequant.partition_fragment_B(sV_pack); + Tensor tSrV_pack = make_fragment_like(tSrV_pack_tmp); + + // S2R + auto smem_tiled_copy_V_pack = make_tiled_copy_B(S2RCopyAtomVPack{}, tiled_mma0_dequant); + auto smem_thr_copy_V_pack = smem_tiled_copy_V_pack.get_slice(thread_idx); + Tensor tSsV = smem_thr_copy_V_pack.partition_S(sV); + Tensor tSsV_pack = smem_thr_copy_V_pack.partition_S(sV_pack); + Tensor tSrV_pack_view = smem_thr_copy_V_pack.retile_D(tSrV_pack); + + // R2S + auto smem_tiled_copy_V_dequant = make_tiled_copy_B(R2SCopyAtomVDequant{}, tiled_mma0_dequant_r2s); + auto smem_thr_copy_V_dequant = smem_tiled_copy_V_dequant.get_slice(thread_idx); + // Tensor tSrV_r2s = smem_thr_copy_V_dequant.retile_S(tSrV_pack); + Tensor tSsV_r2s = smem_thr_copy_V_dequant.partition_D(sV); + Tensor tSrV_dequant_r2s = smem_thr_copy_V_dequant.retile_S(tSrV_dequant); + + // KParams + // cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::KParamsEmpty) /*id*/); + GmemTileCopyKParams gmem_tiled_copy_k_params; + auto gmem_thr_copy_k_params = gmem_tiled_copy_k_params.get_thread_slice(thread_idx); + Tensor tKgK_params = gmem_thr_copy_k_params.partition_S(gK_params); + Tensor tKsK_params = gmem_thr_copy_k_params.partition_D(sK_params); + cute::copy(gmem_tiled_copy_k_params, tKgK_params, tKsK_params); + cute::cp_async_fence(); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast(FwdNamedBarriers::KParamsEmpty) /*id*/); + + // VParams + GmemTileCopyVParams gmem_tiled_copy_v_params; + auto gmem_thr_copy_v_params = gmem_tiled_copy_v_params.get_thread_slice(thread_idx); + Tensor tVgV_params = gmem_thr_copy_v_params.partition_S(gV_params); + Tensor tVsV_params = gmem_thr_copy_v_params.partition_D(sV_params); + + using TensorParamsKC = decltype(make_tensor(make_shape(Int<4 * num_params>{}, Int{}, Int{}))); + using TensorParamsVG = decltype(make_tensor(make_shape(Int{}, Int{}))); + using TensorParamsG = decltype(make_tensor(make_shape(Int{}))); + TensorParamsKC tScales_k_c, tZeros_k_c; + TensorParamsVG 
tScales_v_c, tZeros_v_c; + TensorParamsG tScales_k_g, tZeros_k_g; + auto tScales_k_h2_c = cute::recast<__half2>(tScales_k_c); + auto tZeros_k_h2_c = cute::recast<__half2>(tZeros_k_c); + auto tScales_k_h2_g = cute::recast<__half2>(tScales_k_g); + auto tZeros_k_h2_g = cute::recast<__half2>(tZeros_k_g); + auto tScales_v_h2 = cute::recast<__half2>(tScales_v_c); + auto tZeros_v_h2 = cute::recast<__half2>(tZeros_v_c); + + /******/ + + auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + }; + + // Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter + clear(tOrO); + + int const seqlen_q = seqlen_info.seqlen_q; + int const seqlen_k = seqlen_info.seqlen_k; + int n_block = n_block_max - 1; + + flash::Mask mask( + thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, params.sink_token_length, + params.qhead_per_khead_divmod + ); + + float softcap_val = params.softcap_val; + if constexpr (Has_softcap && Is_FP8) { + float const q_descale = params.ptr_q_descale == nullptr ? 1.0f : params.ptr_q_descale[bidb * get<0>(params.stride_q_descale) + bidh_kv * get<1>(params.stride_q_descale)]; + float const k_descale = params.ptr_k_descale == nullptr ? 1.0f : params.ptr_k_descale[bidb * get<0>(params.stride_k_descale) + bidh_kv * get<1>(params.stride_k_descale)]; + softcap_val *= q_descale * k_descale; + } + // Softcapping needs to happen before masking since if we apply after masking, softcapping can turn + // -inf to e.g. -50.0, which can affect the attention softmax. + auto scoremod_premask_fn = [&](auto& tSrS) { + if constexpr (Has_softcap) { flash::apply_softcap(tSrS, softcap_val); } + }; + + auto &barrier_Q = shared_storage.pipelines.barrier_Q; + barrier_Q.wait(work_idx % 2); + + // No intra-WG overlap + warp_scheduler_barrier_sync(); + auto fwd_step = [&](int const n_block, int const n_block_min, auto mask_fn, auto is_first_iter_type, auto check_inf_type) { + static constexpr bool Is_first_iter = decltype(is_first_iter_type)::value; + static constexpr bool Check_inf = decltype(check_inf_type)::value; + Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); + + // cp_async_wait<0>(); + // cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast(FwdNamedBarriers::KParamsEmpty) /*id*/); + consumer_wait(pipeline_k, smem_pipe_read); + + /* Load Vparams */ + if (!is_first_iter_type) { + tVgV_params.data() = tVgV_params.data() + (-int(kBlockN * v_params_row_stride)); + } + cute::copy(gmem_tiled_copy_v_params, tVgV_params, tVsV_params); + cute::cp_async_fence(); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast(FwdNamedBarriers::KParamsEmpty) /*id*/); + + /* Dequant K */ + cute::copy(smem_tiled_copy_K_pack, tSsK_pack, tSrK_pack_view); + quant::load_params_Ktensor(tScales_k_h2_g, tZeros_k_h2_g, sK_params, thread_idx, num_params); + CUTE_UNROLL + for (int i = 0; i < size<2>(tSrK_pack); ++i) { + quant::dequantize_Ktensor(tSrK_pack, tSrK_dequant, tScales_k_h2_g, tZeros_k_h2_g, 4, group_size, i); + } + cute::copy(smem_tiled_copy_K_dequant, tSrK_dequant_r2s, tSsK_r2s); + + #if DEBUG + if (is_first_iter_type && threadIdx.x == 128 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + PRINT("sK_pack", sK_pack.layout()); PRINTTENSOR("sK_pack", sK_pack); + PRINT("tSrK_pack", tSrK_pack.layout()); PRINTTENSOR("tSrK_pack", 
tSrK_pack); + PRINT("tSrK", tSrK.layout()); // PRINTTENSOR("tSrK", tSrK); + PRINTTENSOR("sK", sK) + } + #endif + /******/ + + flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); + warp_scheduler_barrier_arrive(); + warpgroup_wait<0>(); + + pipeline_k.consumer_release(smem_pipe_read); // release K + + scoremod_premask_fn(tSrS); + mask_fn(tSrS, n_block); + Tensor scores_scale = softmax.template max_get_scale(tSrS); + softmax.template online_softmax(tSrS); + Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); + Tensor tOrP = make_tensor_like(tOrP_acc); + convert_type_out(tOrP_acc, tOrP); + if constexpr (!Is_first_iter) { softmax.rescale_o(tOrO, scores_scale); } + + // cp_async_wait<0>(); + // cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast(FwdNamedBarriers::KParamsEmpty) /*id*/); + consumer_wait(pipeline_v, smem_pipe_read); + + /* Load Kparams */ + if (n_block > n_block_min) { + tKgK_params.data() = tKgK_params.data() + (-int(kBlockN_params * k_params_row_stride)); + cute::copy(gmem_tiled_copy_k_params, tKgK_params, tKsK_params); + cute::cp_async_fence(); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast(FwdNamedBarriers::KParamsEmpty) /*id*/); + } + + /* Dequant V */ + cute::copy(smem_tiled_copy_V_pack, tSsV_pack, tSrV_pack_view); + + CUTE_UNROLL + for (int i = 0; i < size<2>(tSrV_pack); ++i) { + quant::load_params_Vtensor(tScales_v_h2, tZeros_v_h2, sV_params, thread_idx, i, num_params); + quant::dequant_Kchannel_Vtensor(tSrV_pack(_,_,i,_0{}), tSrV_dequant(_,_,i,_0{}), tScales_v_h2(_,i), tZeros_v_h2(_,i), num_params); + } + cute::copy(smem_tiled_copy_V_dequant, tSrV_dequant_r2s, tSsV_r2s); + + #if DEBUG + if (is_first_iter_type && threadIdx.x == 128 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + PRINT("sV_pack", sV_pack.layout()); + PRINT("tSrV_pack", tSrV_pack.layout()); + PRINT("tOrV", tOrV.layout()); + } + #endif + /******/ + + warp_scheduler_barrier_sync(); + flash::gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + pipeline_v.consumer_release(smem_pipe_read); // release V + ++smem_pipe_read; + }; + + auto first_iter_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; + fwd_step(n_block, n_block_min, first_iter_mask_fn, cute::true_type{} /*is_first_iter*/, cute::true_type{} /*check_inf*/); + --n_block; + int const m_idx_max = !PackGQA ? (m_block + 1) * kBlockM : params.qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1; + auto no_mask_fn = [](auto& tSrS, int n_block) { }; + #pragma unroll 1 + for (; n_block >= n_block_min; --n_block) { + fwd_step(n_block, n_block_min, no_mask_fn, cute::false_type{} /*is_first_iter*/, cute::false_type{} /*check_inf*/); + } + + warp_scheduler_barrier_arrive(); + // Tell producers that smem_q is ready + cutlass::arch::NamedBarrier::arrive(NumMmaThreads + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast(FwdNamedBarriers::KParamsEmpty) /*id*/); + float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 
1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; + Tensor scores_scale = softmax.finalize(v_descale); + softmax.rescale_o(tOrO, scores_scale); + if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); } + + ++work_idx; + return true; + } +}; + +} // namespace flash diff --git a/hopper/src/include/mask.h b/hopper/src/include/mask.h new file mode 100644 index 0000000..02d0462 --- /dev/null +++ b/hopper/src/include/mask.h @@ -0,0 +1,157 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#include "cutlass/fast_math.h" // For cutlass::FastDivmod + +#include "utils.h" + +namespace flash { + +using namespace cute; + +template +struct Mask { + + static_assert(!(PackGQA && SwapAB), "Cannot be both PackGQA and SwapAB"); + + int const thread_idx; + int const seqlen_q, seqlen_k; + int const window_size_left, window_size_right, sink_token_length; + cutlass::FastDivmod const qhead_per_khead_divmod; + + CUTLASS_DEVICE + Mask(const int thread_idx, const int seqlen_q, const int seqlen_k, + const int window_size_left, const int window_size_right, const int sink_token_length, + cutlass::FastDivmod const &qhead_per_khead_divmod) + : thread_idx(thread_idx) + , seqlen_q(seqlen_q) + , seqlen_k(seqlen_k) + , window_size_left(window_size_left) + , window_size_right(window_size_right) + , sink_token_length(sink_token_length) + , qhead_per_khead_divmod(qhead_per_khead_divmod) + { + }; + + template + CUTLASS_DEVICE + void apply(Tensor &tSrS, const int m_block, const int n_block) const { + static_assert(!(Causal_mask && Local_mask), "Cannot be both causal and local"); + static_assert(Layout::rank == 3, "Only support 3D Tensor"); + if (!Seqlenk_mask && !Causal_mask && !Local_mask) { return; } + + auto thread_mma = TiledMma{}.get_thread_slice(thread_idx); + auto thread0_mma = TiledMma{}.get_thread_slice(_0{}); + + static constexpr int Row = !SwapAB ? 0 : 1, Col = !SwapAB ? 1 : 0; + + Tensor cS = cute::make_identity_tensor(Shape, Int>{}); + Tensor tScS = thread_mma.partition_C(cS); + Tensor tSrS_rowcol = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol(tSrS.layout())); + Tensor tScS_rowcol = make_tensor(tScS.data(), flash::convert_layout_acc_rowcol(tScS.layout())); + Tensor t0ScS = thread0_mma.partition_C(cS); + Tensor t0ScS_rowcol = make_tensor(t0ScS.data(), flash::convert_layout_acc_rowcol(t0ScS.layout())); + // We want to use the col indices of thread0 to compare, since that is known at compile time. 
+ // So we subtract the limit by the first col index of this thread (get(tScS_rowcol(_0{}, _0{}))) + int const thread_col_offset = get(tScS_rowcol(_0{}, _0{})); + int const seqlenk_col_limit = seqlen_k - n_block * kBlockN - thread_col_offset; + if constexpr (!Causal_mask && !Local_mask) { + if constexpr (Seqlenk_mask) { // Just masking based on col + #pragma unroll + for (int n = 0; n < size<1>(tSrS_rowcol); ++n) { + if (int(get(t0ScS_rowcol(_0{}, n))) >= seqlenk_col_limit) { + #pragma unroll + for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { tSrS_rowcol(m, n) = -INFINITY; } + } + } + } + } else { // mask based on both row and col + if constexpr (!SwapAB) { + // If PackGQA, we split the work of compute divmod among threads in the same row + static constexpr int kMmaThreadsPerRow = size<0, 0>(typename TiledMma::AtomLayoutC_TV{}); + static_assert(cutlass::NumThreadsPerWarp % kMmaThreadsPerRow == 0); + static_assert(!PackGQA || CUTE_STATIC_V(size<0>(tSrS_rowcol)) <= kMmaThreadsPerRow); + int mma_m_idx; + // Might get OOB but it's ok since we'll check it later + if constexpr (PackGQA) { + mma_m_idx = qhead_per_khead_divmod.divide(m_block * kBlockM + get(tScS_rowcol(thread_idx % kMmaThreadsPerRow, _0{}))); + } + int const causal_row_offset = 1 + seqlen_k - n_block * kBlockN - seqlen_q - thread_col_offset; + if constexpr (Causal_mask) { + #pragma unroll + for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { + int const row_idx = !PackGQA + ? get(tScS_rowcol(m, _0{})) + m_block * kBlockM + : __shfl_sync(0xffffffff, mma_m_idx, m % kMmaThreadsPerRow, kMmaThreadsPerRow); + int const col_limit_right = !Seqlenk_mask + ? row_idx + causal_row_offset + : __viaddmin_s32(row_idx, causal_row_offset, seqlenk_col_limit); + #pragma unroll + for (int n = 0; n < size<1>(tSrS_rowcol); ++n) { + if (int(get(t0ScS_rowcol(_0{}, n))) >= col_limit_right) { tSrS_rowcol(m, n) = -INFINITY; } + } + } + } else { + int const local_row_offset_right = causal_row_offset + window_size_right; + int const local_row_offset_left = causal_row_offset - 1 - window_size_left; + int const col_limit_sink = sink_token_length - n_block * kBlockN; + #pragma unroll + for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { + int const row_idx = !PackGQA + ? get(tScS_rowcol(m, _0{})) + m_block * kBlockM + : __shfl_sync(0xffffffff, mma_m_idx, m % kMmaThreadsPerRow, kMmaThreadsPerRow); + int const col_limit_right = !Seqlenk_mask + ? row_idx + local_row_offset_right + : __viaddmin_s32(row_idx, local_row_offset_right, seqlenk_col_limit); + int const col_limit_left = row_idx + local_row_offset_left; + #pragma unroll + for (int n = 0; n < size<1>(tSrS_rowcol); ++n) { + int const col_idx = int(get(t0ScS_rowcol(m, n))); + if (col_idx >= col_limit_right || (col_idx < col_limit_left && col_idx >= col_limit_sink)) { tSrS_rowcol(m, n) = -INFINITY; } + } + } + } + } else { + int const thread_row_offset = get(tScS_rowcol(_0{}, _0{})); + int const causal_row_offset = seqlenk_col_limit - seqlen_q + m_block * kBlockM + thread_row_offset; + if constexpr (Causal_mask) { + #pragma unroll + for (int n = 0; n < size<1>(tSrS_rowcol); ++n) { + int const col0 = int(get(t0ScS_rowcol(_0{}, n))); + // If col0 is beyond the column limit, we want to mask out the entire column, by setting + // row limit to be kBlockM. + int const row_limit_top = col0 >= seqlenk_col_limit ? 
kBlockM : col0 - causal_row_offset; + #pragma unroll + for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { + if (int(get(t0ScS_rowcol(m, _0{}))) < row_limit_top) { tSrS_rowcol(m, n) = -INFINITY; } + } + } + } else { + int const col_limit_sink = sink_token_length - n_block * kBlockN - thread_col_offset; + #pragma unroll + for (int n = 0; n < size<1>(tSrS_rowcol); ++n) { + int const col0 = int(get(t0ScS_rowcol(_0{}, n))); + // If col0 is beyond the column limit, we want to mask out the entire column, by setting + // row limit to be kBlockM. + int const row_limit_top = col0 >= seqlenk_col_limit ? kBlockM : col0 - causal_row_offset - window_size_right; + int const row_limit_bot = col0 < col_limit_sink ? kBlockM : col0 - causal_row_offset + window_size_left; + #pragma unroll + for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { + int const row_idx = int(get(t0ScS_rowcol(m, _0{}))); + if (row_idx < row_limit_top || row_idx > row_limit_bot) { tSrS_rowcol(m, n) = -INFINITY; } + } + } + } + } + } + }; + +}; + +} // namespace flash diff --git a/hopper/src/include/named_barrier.hpp b/hopper/src/include/named_barrier.hpp new file mode 100644 index 0000000..3b0f50a --- /dev/null +++ b/hopper/src/include/named_barrier.hpp @@ -0,0 +1,78 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include "cutlass/arch/barrier.h" + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// cutlass::arch::NamedBarrier::sync/arrive are only enabled Sm90 even though they work +// for Sm80 as well. We reimplement them here, enabled for both Sm90 and Sm80. 
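+// Usage sketch (illustrative): the id passed in is offset past CUTLASS's reserved barriers,
+// bar.arrive signals without blocking, and bar.sync blocks until num_threads arrivals (from
+// either call) have been observed on that barrier. E.g. the producer warp can wait for the MMA
+// warpgroups to release smem_q roughly like this:
+//   // MMA warpgroups, once they are done reading sQ:
+//   flash::named_barrier_arrive(NumMmaThreads + NumProducerThreads,
+//                               static_cast<uint32_t>(FwdNamedBarriers::QueryEmpty));
+//   // producer warp, before overwriting sQ:
+//   flash::named_barrier_sync(NumMmaThreads + NumProducerThreads,
+//                             static_cast<uint32_t>(FwdNamedBarriers::QueryEmpty));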
+ +CUTLASS_DEVICE +static void named_barrier_sync(uint32_t num_threads, uint32_t barrier_id_) { + static constexpr uint32_t ReservedNamedBarrierCount = static_cast(cutlass::arch::ReservedNamedBarriers::FirstUserBarrier); + uint32_t barrier_id = barrier_id_ + ReservedNamedBarrierCount; + asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); + cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, barrier_id); +} + +CUTLASS_DEVICE +static void named_barrier_sync(uint32_t num_threads, cutlass::arch::ReservedNamedBarriers reserved_named_barriers) { + uint32_t barrier_id = static_cast(reserved_named_barriers); + asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); + cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, barrier_id); +} + +CUTLASS_DEVICE +static void named_barrier_arrive(uint32_t num_threads, uint32_t barrier_id_) { + static constexpr uint32_t ReservedNamedBarrierCount = static_cast(cutlass::arch::ReservedNamedBarriers::FirstUserBarrier); + uint32_t barrier_id = barrier_id_ + ReservedNamedBarrierCount; + cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, barrier_id); + asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads)); +} + +CUTLASS_DEVICE +static void named_barrier_arrive(uint32_t num_threads, cutlass::arch::ReservedNamedBarriers reserved_named_barriers) { + uint32_t barrier_id = static_cast(reserved_named_barriers); + cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, barrier_id); + asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads)); +} + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Enumerates the reserved named barriers to avoid potential conflicts + +enum class FwdNamedBarriers { + QueryEmpty = 0, + ProducerWG = 1, + TileCountSmemEmpty = 2, + TileCountSmemFull = 3, + WarpSchedulerWG1 = 4, + WarpSchedulerWG2 = 5, + WarpSchedulerWG3 = 6, + AppendKV = 7, + QueryRotated = 8, + KParamsEmpty = 9, + VParamsEmpty = 10, +}; + +enum class BwdNamedBarriers { + KVEmpty = 0, + PdS = 1, + // This needs to match FwdNamedBarriers::TileCountSmemEmpty since TileScheduler uses it + TileCountSmemEmpty = 2, + TileCountSmemFull = 3, + dQEmptyWG1 = 4, + dQEmptyWG2 = 5, + dQEmptyWG3 = 6, + dQFullWG1 = 7, + dQFullWG2 = 8, + dQFullWG3 = 9, +}; + +} // flash diff --git a/hopper/src/include/pack_gqa.h b/hopper/src/include/pack_gqa.h new file mode 100644 index 0000000..160bf43 --- /dev/null +++ b/hopper/src/include/pack_gqa.h @@ -0,0 +1,255 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#include "cutlass/fast_math.h" // For cutlass::FastDivmod + +#include "utils.h" + +namespace flash { + +using namespace cute; + +template +struct PackGQAManager { + // We use CpAsync for Q, since TMA doesn't work there + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static constexpr int kGmemElemsPerStore = kGmemElemsPerLoad; + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); + // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. 
if hdim=128, we want each + // thread to have 4 loads in the M direction and 2 vectorized load in the K direction. + // In the case of PackGQA, this reduces the number of times we need to call divmod. + static constexpr int kBytePerRow = kHeadDim * sizeof(Element); + static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); + static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; + static_assert(NumThreads % kGmemThreadsPerRow == 0, "NumThreads must be a multiple of kGmemThreadsPerRow"); + // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKV where + // these threads share the same page table entry and share the work of computing pointers to paged K and paged V. + static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, "kGmemThreadsPerRow must divide NumThreadsPerWarp"); + using GmemCopyAtomCpAsync = cute::Copy_Atom, Element>; + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + using GmemTiledCopyQCpAsync = decltype( + make_tiled_copy(GmemCopyAtomCpAsync{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 8 or 16 vals per load + + // Was trying to have each WG loading Q to the rows in sQ that only that WG needs so that we only need + // to sync within each WG, but didn't seem to be any faster. + // using GmemLayoutAtomWG = Layout, Int, Int >, + // Stride, _128, _1>>; + // using GmemTiledCopyQCpAsyncWG = decltype( + // make_tiled_copy(GmemCopyAtomCpAsync{}, + // GmemLayoutAtomNew{}, + // Layout>>{})); // Val layout, 8 or 16 vals per load + + using GmemTiledCopyO = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 8 or 16 vals per store + + template + CUTLASS_DEVICE + static auto + compute_ptr(Tensor &tensor, TensorC const &tRows, + cutlass::FastDivmod const &qhead_per_khead_divmod, int const thread_idx, int const m_block) { + // tensor of shape ((qhead_per_khead, seqlen_q)) + static constexpr int NumPtrPerThread = cute::ceil_div(CUTE_STATIC_V(cute::size(tRows)), NumThreadsPerRow); + using TensorType = typename Engine::value_type; + Tensor tPrPtr = make_tensor(Shape>{}); + #pragma unroll + for (int i = 0; i < NumPtrPerThread; ++i) { + int const row = i * NumThreads + get<0>(tRows(thread_idx % NumThreadsPerRow)); + int const idx = m_block * kBlockM + row; + int m_idx, h_idx; + m_idx = qhead_per_khead_divmod.divmod(h_idx, idx); + tPrPtr[i] = &tensor(make_coord(make_coord(h_idx, m_idx))); + } + return tPrPtr; + } + + + template + CUTLASS_DEVICE + static void + load_Q(TensormQ const &mQ, // ((qhead_per_khead, seqlen_q), headdim) + TensorsQ &sQ, // (kBlockM, kHeadDim) + cutlass::FastDivmod const &qhead_per_khead_divmod, + int const thread_idx, int const seqlen_q, int const m_block + ) + { + GmemTiledCopyQCpAsync gmem_tiled_copy_Q_cp_async; + // GmemTiledCopyQCpAsyncNew gmem_tiled_copy_Q_cp_async; + auto gmem_thr_copy_Q_cp_async = gmem_tiled_copy_Q_cp_async.get_thread_slice(thread_idx); + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tQcQ = gmem_thr_copy_Q_cp_async.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tQsQ = gmem_thr_copy_Q_cp_async.partition_D(sQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + // Tensor tQcQ_ = gmem_thr_copy_Q_cp_async.partition_S(cute::flat_divide(cQ, _64{})); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + // Tensor tQsQ_ = gmem_thr_copy_Q_cp_async.partition_D(cute::flat_divide(sQ, 
_64{})); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + // Tensor tQcQ = group_modes<1, rank(tQcQ_) - 1>(tQcQ_); + // Tensor tQsQ = group_modes<1, rank(tQsQ_) - 1>(tQsQ_); + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(_0{}, _0{}, k)) < size<1>(mQ); } + + // Similar to loading K and V when PagedKV, it's expensive to compute the pointers for Q. + // We split the work among threads loading the same row of Q, then __shfl_sync the pointers. + Tensor mQ_0 = mQ(_, _0{}); + Tensor tQcQ_row = tQcQ(_0{}, _, _0{}); + Tensor tPrQPtr = compute_ptr(mQ_0, tQcQ_row, qhead_per_khead_divmod, thread_idx, m_block); + int const qhead_per_khead = qhead_per_khead_divmod.divisor; + #pragma unroll + for (int m = 0; m < size<1>(tQsQ); ++m) { + int idx = m_block * kBlockM + get<0>(tQcQ(_0{}, m, _0{})); + Element const* q_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrQPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); + if (idx < seqlen_q * qhead_per_khead) { + // if (thread_idx == 0) { printf("m: %d, m_idx: %d, h_idx: %d, q_ptr = %p, q_ptr_og = %p\n", m, m_idx, h_idx, q_ptr, &mQ_copy(0, make_coord(h_idx, m_idx), 0));} + Tensor mQ_cur = make_tensor(make_gmem_ptr(q_ptr), Shape>{}); + Tensor mQ_cur_copy = cute::tiled_divide(mQ_cur, Shape>{}); + #pragma unroll + for (int k = 0; k < size<2>(tQsQ); ++k) { + int ki = get<1>(tQcQ(_0{}, _0{}, k)) / kGmemElemsPerLoad; + // the "tiled_copy.with(tQpQ(k))"" will fill in zero for columns where tQpQ(k) is false + // TODO: check this + cute::copy(gmem_tiled_copy_Q_cp_async.with(tQpQ(k)), mQ_cur_copy(_, ki), tQsQ(_, m, k)); + } + } // Don't need to fill in 0s for sQ since we're not gonna write the output to gmem for those rows + } + }; + + template + CUTLASS_DEVICE + static void + store_LSE(TensormLSE &mLSE, // ((qhead_per_khead, seqlen_q)) + TensorsLSE const &tLSErLSE, // (kBlockM) split across threads according to tiled_mma + TiledMma tiled_mma, + cutlass::FastDivmod const &qhead_per_khead_divmod, + int const thread_idx, int const seqlen_o, int const m_block + ) + { + Tensor caccO = cute::make_identity_tensor(Shape, Int>{}); + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor taccOcO = thread_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + Tensor taccOcO_row = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout()))(_, _0{}); + CUTE_STATIC_ASSERT_V(size(tLSErLSE) == size(taccOcO_row)); // MMA_M + + // If PackGQA, we split the work of compute divmod among threads in the same row + static constexpr int kMmaThreadsPerRow = size<0, 0>(typename TiledMma::AtomLayoutC_TV{}); + static_assert(cutlass::NumThreadsPerWarp % kMmaThreadsPerRow == 0); + static_assert(CUTE_STATIC_V(size(tLSErLSE)) <= kMmaThreadsPerRow); + static_assert(CUTE_STATIC_V(size(taccOcO_row)) <= kMmaThreadsPerRow); + + Tensor tPrLSEPtr = compute_ptr(mLSE, taccOcO_row, qhead_per_khead_divmod, thread_idx, m_block); + static_assert(CUTE_STATIC_V(size(tPrLSEPtr)) == 1); + int const qhead_per_khead = qhead_per_khead_divmod.divisor; + #pragma unroll + for (int mi = 0; mi < size(tLSErLSE); ++mi) { + int const row = m_block * kBlockM + get<0>(taccOcO_row(mi)); + float* ptr_LSE_cur = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrLSEPtr[0]), mi % kMmaThreadsPerRow, kMmaThreadsPerRow)); + if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o * qhead_per_khead) { + *ptr_LSE_cur = tLSErLSE(mi); + } + } + }; + + template + CUTLASS_DEVICE + static void + 
store_O(TensormO &mO, // ((qhead_per_khead, seqlen_o), headdim) + TensorrO const &tOrO, // (kBlockM, kHeadDim) split across threads according to gmem_tiled_copy_O + cutlass::FastDivmod const &qhead_per_khead_divmod, + int const thread_idx, int const seqlen_o, int const m_block + ) + { + GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); + Tensor cO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOcO))); + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < size<1>(mO); } + + // Similar to loading K and V when PagedKV, it's expensive to compute the pointers for O. + // We split the work among threads loading the same row of O, then __shfl_sync the pointers. + Tensor mO_0 = mO(_, _0{}); + Tensor tOcO_row = tOcO(_0{}, _, _0{}); + Tensor tPrOPtr = compute_ptr(mO_0, tOcO_row, qhead_per_khead_divmod, thread_idx, m_block); + int const qhead_per_khead = qhead_per_khead_divmod.divisor; + #pragma unroll + for (int m = 0; m < size<1>(tOrO); ++m) { + int idx = m_block * kBlockM + get<0>(tOcO(_0{}, m, _0{})); + Element* o_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrOPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); + if (idx < seqlen_o * qhead_per_khead) { + Tensor mO_cur = make_tensor(make_gmem_ptr(o_ptr), Shape>{}); + Tensor mO_cur_copy = cute::tiled_divide(mO_cur, Shape>{}); + #pragma unroll + for (int k = 0; k < size<2>(tOrO); ++k) { + int ki = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerStore; + if (tOpO(k)) { + cute::copy(gmem_tiled_copy_O, tOrO(_, m, k), mO_cur_copy(_, ki)); + } + } + } + } + }; + + template + CUTLASS_DEVICE + static void + store_O_direct(TensormO &mO, // ((qhead_per_khead, seqlen_o), headdim) + TensorrO const &tOrO, // (kBlockM, kHeadDim) split across threads according to tiled_mma + TiledMma tiled_mma, + cutlass::FastDivmod const &qhead_per_khead_divmod, + int const thread_idx, int const seqlen_o, int const m_block + ) + { + static constexpr int kGmemElemsPerStoreDirect = 2; + cute::Copy_Atom, Element> gmem_copy_direct; + // Reshape acc from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor tOrO_rowcol = make_tensor(tOrO.data(), flash::convert_layout_acc_rowcol(tOrO.layout())); + Tensor tOrO_copy = cute::tiled_divide(tOrO_rowcol, Shape<_1, Int>{}); + + Tensor caccO = cute::make_identity_tensor(Shape, Int>{}); + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor taccOcO = thread_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); + Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); + Tensor taccOcO_col = taccOcO_rowcol(_0{}, _); + + // If PackGQA, we split the work of compute divmod among threads in the same row + static constexpr int kMmaThreadsPerRow = size<0, 0>(typename TiledMma::AtomLayoutC_TV{}); + static_assert(cutlass::NumThreadsPerWarp % kMmaThreadsPerRow == 0); + static_assert(CUTE_STATIC_V(size(taccOcO_row)) <= kMmaThreadsPerRow); + + // Similar to loading K and V when PagedKV, it's expensive to compute the pointers for O. + // We split the work among threads loading the same row of O, then __shfl_sync the pointers. 
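+        // A rough sketch of the pointer-sharing trick (placeholder names; the actual exchange uses
+        // tPrOPtr computed below): the owning lane turns its row pointer into 64-bit integer bits,
+        // broadcasts them within its row group, and the receiving lanes cast back:
+        //   uint64_t bits = reinterpret_cast<uint64_t>(row_ptr);
+        //   bits = __shfl_sync(0xffffffff, bits, src_lane, kMmaThreadsPerRow);
+        //   Element* row_ptr_bcast = reinterpret_cast<Element*>(bits);
+        // so only one lane per row group pays the int64 address arithmetic.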
+ Tensor mO_0 = mO(_, _0{}); + Tensor tPrOPtr = compute_ptr(mO_0, taccOcO_row, qhead_per_khead_divmod, thread_idx, m_block); + static_assert(CUTE_STATIC_V(size(tPrOPtr)) == 1); + + int const qhead_per_khead = qhead_per_khead_divmod.divisor; + #pragma unroll + for (int m = 0; m < size<1>(tOrO_copy); ++m) { + int row = m_block * kBlockM + get<0>(taccOcO_row(m)); + Element* o_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrOPtr[0]), m % kMmaThreadsPerRow, kMmaThreadsPerRow)); + if (row < seqlen_o * qhead_per_khead) { + Tensor mO_cur = make_tensor(make_gmem_ptr(o_ptr), Shape>{}); + Tensor mO_cur_copy = cute::tiled_divide(mO_cur, Shape>{}); + #pragma unroll + for (int k = 0; k < size<2>(tOrO_copy); ++k) { + int col = get<1>(taccOcO_col(k * kGmemElemsPerStoreDirect)); + if (col < size<1>(mO)) { + cute::copy(gmem_copy_direct, tOrO_copy(_, m, k), mO_cur_copy(_, col / kGmemElemsPerStoreDirect)); + } + } + } + } + }; + +}; + +} // namespace flash diff --git a/hopper/src/include/paged_kv.h b/hopper/src/include/paged_kv.h new file mode 100644 index 0000000..0f710e5 --- /dev/null +++ b/hopper/src/include/paged_kv.h @@ -0,0 +1,301 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#include "cutlass/fast_math.h" // For cutlass::FastDivmod + +#include "utils.h" + +namespace flash { + +using namespace cute; + +template +struct PagedKVManager { + // If KV_Same_Iter=false, then we do load_page_table(0), load_K(0), load_page_table(1), load_K(1), load_V(0), + // load_page_table(2), load_K(2), load_V(1), etc. + // So we need to compute the V pointers for the previous iteration. + + // LoadsPerRow_LB is the lower bound on number of loads per row in the K direction. This is useful for + // rotary where we want each thread to have at least 2 loads per row. + + // We use CpAsync for K and V if PagedKV, since TMA doesn't work there + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); + // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each + // thread to have 4 loads in the M direction and 2 vectorized load in the K direction. + // In the case of PackGQA, this reduces the number of times we need to call divmod. + static_assert(kHeadDim % LoadsPerRow_LB == 0, "Headdim must be a multiple of LoadsPerRow_LB"); + static constexpr int kBytePerRow = kHeadDim / LoadsPerRow_LB * sizeof(Element); + static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); + static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; + static_assert(NumThreads % kGmemThreadsPerRow == 0, "NumThreads must be a multiple of kGmemThreadsPerRow"); + // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKV where + // these threads share the same page table entry and share the work of computing pointers to paged K and paged V. 
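+    // For example, assuming Element = cutlass::half_t, kHeadDim = 128, LoadsPerRow_LB = 1 and
+    // NumThreads = 128 (illustrative numbers, not fixed by this header):
+    //   kGmemElemsPerLoad  = 16 B / 2 B  = 8 elements per 128-bit cp.async
+    //   kBytePerRow        = 128 * 2 B   = 256 B
+    //   kBlockKGmem        = 128 B / 2 B = 64 elements, i.e. one 128-B cache line per "row"
+    //   kGmemThreadsPerRow = 64 / 8      = 8 threads cooperating on each row,
+    // so the tiled copy below covers 128 / 8 = 16 rows of K/V at a time.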
+ static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, "kGmemThreadsPerRow must divide NumThreadsPerWarp"); + using GmemCopyAtomCpAsync = cute::Copy_Atom, Element>; + using GmemLayoutAtomKVCpAsync = Layout, Int>, + Stride, _1>>; + using GmemTiledCopyKVCpAsync = decltype( + make_tiled_copy(GmemCopyAtomCpAsync{}, + GmemLayoutAtomKVCpAsync{}, + Layout>>{})); // Val layout, 8 or 16 vals per load + using GmemTiledCopyKVStore = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtomKVCpAsync{}, + Layout>>{})); // Val layout, 8 or 16 vals per load + + using ShapeKV = cute::Shape; // (seqlen, d, head, batch) + using StrideKV = cute::Stride; + using ShapePageTable = cute::Shape; // (batch, max_num_pages_per_seq) + using StridePageTable = cute::Stride; + + using TensorPageTable = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapePageTable{}, StridePageTable{})(int(0), _)); + using TensorKV = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeKV{}, StrideKV{})(_, _, int(0), _)); + using GmemThrCopyKVCpAsync = decltype(GmemTiledCopyKVCpAsync{}.get_thread_slice(int(0))); + using TensortKcK = decltype(GmemTiledCopyKVCpAsync{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape, Int>{}))); + using TensortKpK = decltype(make_tensor(make_shape(size<1>(TensortKcK{}), size<2>(TensortKcK{})), Stride<_0, _1>{})); + + // For PagedKV, it's expensive the calculate the pointers to K and V for each page table entry, + // since those require int64_t arithmetic. We optimize by having threads split this work. + // Typically there are 8 threads loading per row (e.g. hdim 64 and 128), and there are 11 rows + // that each thread needs to load for the case of hdim 128 and kBlockN = 176. + // So each of those 8 threads will calculate the K_ptr and V_ptr for 11 / 8 = 2 rows. + // We then use __shfl_sync to broadcast the pointers to the other threads in the warp. 
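+    // Concretely (illustrative numbers, leftpad_k = 0): with a page size of 256 and logical row
+    // index row_idx = 700, page_size_divmod yields page_idx = 2 and page_offset = 188, so the row
+    // lives at row 188 of physical page page = mPageTable[2] and its K pointer is
+    // &mK_paged(188, _0{}, page). The divmod itself is cheap; it is the int64_t pointer/stride
+    // arithmetic that gets split across the threads of a row group and shared with __shfl_sync.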
+ static constexpr int kPageEntryPerThread = cute::ceil_div(size<1>(TensortKcK{}), kGmemThreadsPerRow); + using TensorPageOffset = decltype(make_tensor>(Shape>{})); + using TensorKVPtr = decltype(make_tensor(Shape>{})); + + GmemTiledCopyKVCpAsync gmem_tiled_copy_kv; + cutlass::FastDivmod const &page_size_divmod; + int const thread_idx; + int const seqlen_k; + int const leftpad_k; + GmemThrCopyKVCpAsync const gmem_thr_copy_kv; + TensorPageTable mPageTable; + TensorKV mK_paged, mV_paged; + TensortKpK tKpK; + TensorPageOffset tPrPageOffset; + TensorKVPtr tPrVPtr; + + + CUTLASS_DEVICE + PagedKVManager(int const* const ptr_page_table, + ShapePageTable const &shape_pagetable, StridePageTable const &stride_pagetable, + Element* const ptr_K, ShapeKV const &shape_K, StrideKV const &stride_K, + Element* const ptr_V, StrideKV const &stride_V, + cutlass::FastDivmod const &page_size_divmod, + int const bidb, int const bidh, int const thread_idx, int const seqlen_k, int const leftpad_k + ) + : page_size_divmod(page_size_divmod) + , thread_idx(thread_idx) + , seqlen_k(seqlen_k) + , leftpad_k(leftpad_k) + , gmem_thr_copy_kv(gmem_tiled_copy_kv.get_thread_slice(thread_idx)) + + { + mPageTable = make_tensor(make_gmem_ptr(ptr_page_table), shape_pagetable, stride_pagetable)(bidb, _); + mK_paged = make_tensor(make_gmem_ptr(ptr_K), shape_K, stride_K)(_, _, bidh, _); + mV_paged = make_tensor(make_gmem_ptr(ptr_V), shape_K, stride_V)(_, _, bidh, _); + tKpK = make_tensor(make_shape(size<1>(TensortKcK{}), size<2>(TensortKcK{})), Stride<_0, _1>{}); + + Tensor cK = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tKcK = gmem_thr_copy_kv.partition_S(cK); + #pragma unroll + for (int k = 0; k < size<1>(tKpK); ++k) { tKpK(_0{}, k) = get<1>(tKcK(_0{}, _0{}, k)) < get<1>(shape_K); } + }; + + template + CUTLASS_DEVICE + void load_page_table(const int n_block) { + // The uncoalesced gmem load is intentional. This is so that each thread only loads the page table entries + // it needs, and we don't need any sync between warps. + // Assuming 8 threads per row, and 176 rows, then the rows from 0 to 175 are loaded by + // threads 0, 8, 16, ..., 120, 1, 9, ..., 121, 2, 10, ..., 122, etc. + #pragma unroll + for (int i = 0; i < kPageEntryPerThread; ++i) { + int const row = i * NumThreads + (thread_idx % kGmemThreadsPerRow) * (NumThreads / kGmemThreadsPerRow) + (thread_idx / kGmemThreadsPerRow); + int const row_idx = n_block * kBlockN + row; + int page_idx, page_offset; + page_idx = page_size_divmod.divmod(page_offset, row_idx + leftpad_k); + // Add the condition (i + 1) * NumThreads <= kBlockN since that is an upper bound of row + // and is known at compile time. It avoids branching when e.g., kBlockN = 176 and i = 0. + int const page = ((i + 1) * NumThreads <= kBlockN || row < kBlockN) && (!Seqlenk_mask || row_idx < seqlen_k) ? 
mPageTable[page_idx] : 0; + tPrPageOffset[i] = {page, page_offset}; + // if (cute::thread0()) { printf("row = %d, page_idx = %d, page_offset = %d, page = %d, leftpad_k = %d, seqlen_k = %d\n", row, page_idx, page_offset, page, leftpad_k, seqlen_k); } + } + if constexpr (First_iter && !KV_Same_Iter) { compute_V_ptr(); } + }; + + CUTLASS_DEVICE + TensorKVPtr compute_K_ptr() { + Tensor tPrKPtr = make_tensor(Shape>{}); + #pragma unroll + for (int i = 0; i < kPageEntryPerThread; ++i) { + auto [page, page_offset] = tPrPageOffset[i]; + tPrKPtr[i] = &mK_paged(page_offset, _0{}, page); + } + return tPrKPtr; + }; + + CUTLASS_DEVICE + void compute_V_ptr() { + #pragma unroll + for (int i = 0; i < kPageEntryPerThread; ++i) { + auto [page, page_offset] = tPrPageOffset[i]; + tPrVPtr[i] = &mV_paged(page_offset, _0{}, page); + } + }; + + template + CUTLASS_DEVICE + void load_K(const int n_block, TensorK &&sK) { + // Do we need bound check to make sure the row doesn't go above kBlockN + static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtomKVCpAsync{})) == 0; + + Tensor tPrKPtr = compute_K_ptr(); + + // Only for index calculation, since all the indices of thread 0 are known at compile time + auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{}); + Tensor tKsK = gmem_thr_copy_kv.partition_D(sK); + Tensor cK = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // Repeat the partitioning with identity layouts + Tensor tKcK = gmem_thr_copy_kv.partition_S(cK); + Tensor t0KcK = gmem_thr0_copy_kv.partition_S(cK); + + // We want to use the row indices of thread0 to compare, since that is known at compile time. + // So we subtract the limit by the first row index of this thread (get<0>(tKcK(_0{}, _0{}, _0{}))) + int const seqlenk_row_limit = -int(get<0>(tKcK(_0{}, _0{}, _0{}))) + (EvenN + ? seqlen_k - n_block * kBlockN + : (!Seqlenk_mask ? kBlockN : std::min(seqlen_k - n_block * kBlockN, kBlockN))); + #pragma unroll + for (int m = 0; m < size<1>(tKsK); ++m) { + bool const should_load = EvenN + ? 
(!Seqlenk_mask || get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit) + : get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit; + Element const* k_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow)); + Tensor mK_paged_cur = make_tensor(make_gmem_ptr(k_ptr), Shape>{}); + Tensor mK_paged_cur_copy = cute::tiled_divide(mK_paged_cur, Shape>{}); + if (should_load) { + #pragma unroll + for (int k = 0; k < size<2>(tKsK); ++k) { + int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad; + cute::copy(gmem_tiled_copy_kv.with(tKpK(_0{}, k)), mK_paged_cur_copy(_, ki), tKsK(_, m, k)); + } + } // Don't need to clear out the rest of the smem since we'll mask out the scores anyway + } + }; + + template + CUTLASS_DEVICE + void load_V(const int n_block, TensorV &&sV) { + // Do we need bound check to make sure the row doesn't go above kBlockN + static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtomKVCpAsync{})) == 0; + + if constexpr (KV_Same_Iter) { compute_V_ptr(); } + // Only for index calculation, since all the indices of thread 0 are known at compile time + auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{}); + Tensor tVsV = gmem_thr_copy_kv.partition_D(sV); + Tensor cK = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // Repeat the partitioning with identity layouts + Tensor tKcK = gmem_thr_copy_kv.partition_S(cK); + Tensor t0KcK = gmem_thr0_copy_kv.partition_S(cK); + + int const seqlenk_row_limit = seqlen_k - n_block * kBlockN - get<0>(tKcK(_0{}, _0{}, _0{})); + #pragma unroll + for (int m = 0; m < size<1>(tVsV); ++m) { + // Faster to rely on the cp.async to clear smem that are out of bound, + // rather than calling cute::clear directly. + // We have to be careful not to write to smem past `kBlockN` if !EvenN. + // If kBlockN doesn't evenly divide the tiled copy, only the last `m` needs to checked + if (EvenN || m < size<1>(tVsV) - 1 || get<0>(tKcK(_0{}, m, _0{})) < kBlockN) { + bool const should_load = !Seqlenk_mask || get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit; + Element const* v_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrVPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); + Tensor mV_paged_cur = make_tensor(make_gmem_ptr(v_ptr), Shape>{}); + Tensor mV_paged_cur_copy = cute::tiled_divide(mV_paged_cur, Shape>{}); + #pragma unroll + for (int k = 0; k < size<2>(tVsV); ++k) { + int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad; + cute::copy(gmem_tiled_copy_kv.with(tKpK(_0{}, k) && should_load), mV_paged_cur_copy(_, ki), tVsV(_, m, k)); + } + } + } + if constexpr (!KV_Same_Iter) { compute_V_ptr(); } + }; + + template + CUTLASS_DEVICE + void store_K(const int n_block, TensorK &&tKrK) { + Tensor tPrKPtr = compute_K_ptr(); + // We're using the same partitioning as GmemTiledCopyKVCpAsync (used for loading) + // Only for index calculation, since all the indices of thread 0 are known at compile time + auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{}); + Tensor cK = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // Repeat the partitioning with identity layouts + Tensor tKcK = gmem_thr_copy_kv.partition_S(cK); + Tensor t0KcK = gmem_thr0_copy_kv.partition_S(cK); + + GmemTiledCopyKVStore gmem_tiled_copy_kv_store; + // We want to use the row indices of thread0 to compare, since that is known at compile time. 
+ // So we subtract the limit by the first row index of this thread (get<0>(tKcK(_0{}, _0{}, _0{}))) + // int const seqlenk_row_limit = seqlen_k - n_block * kBlockN - get<0>(tKcK(_0{}, _0{}, _0{})); + int const seqlenk_row_limit = std::min(seqlen_k - n_block * kBlockN, kBlockN) - get<0>(tKcK(_0{}, _0{}, _0{})); + // if (threadIdx.x == 128) { printf("bidx = %d, bidy = %d, bidz = %d, seqlen_k = %d, seqlenk_row_limit = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, seqlen_k, seqlenk_row_limit); } + #pragma unroll + for (int m = 0; m < size<1>(tKrK); ++m) { + bool const should_load = get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit; + Element* k_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow)); + Tensor mK_paged_cur = make_tensor(make_gmem_ptr(k_ptr), Shape>{}); + Tensor mK_paged_cur_copy = cute::tiled_divide(mK_paged_cur, Shape>{}); + if (should_load) { + #pragma unroll + for (int k = 0; k < size<2>(tKrK); ++k) { + int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad; + if (tKpK(_0{}, k)) { + cute::copy(gmem_tiled_copy_kv_store, tKrK(_, m, k), mK_paged_cur_copy(_, ki)); + } + } + } + } + }; + + template + CUTLASS_DEVICE + void store_V(const int n_block, TensorV &&tVrV) { + if constexpr (KV_Same_Iter) { compute_V_ptr(); } + // Only for index calculation, since all the indices of thread 0 are known at compile time + auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{}); + Tensor cK = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // Repeat the partitioning with identity layouts + Tensor tKcK = gmem_thr_copy_kv.partition_S(cK); + Tensor t0KcK = gmem_thr0_copy_kv.partition_S(cK); + + GmemTiledCopyKVStore gmem_tiled_copy_kv_store; + int const seqlenk_row_limit = std::min(seqlen_k - n_block * kBlockN, kBlockN) - get<0>(tKcK(_0{}, _0{}, _0{})); + #pragma unroll + for (int m = 0; m < size<1>(tVrV); ++m) { + bool const should_load = get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit; + Element* v_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrVPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); + Tensor mV_paged_cur = make_tensor(make_gmem_ptr(v_ptr), Shape>{}); + Tensor mV_paged_cur_copy = cute::tiled_divide(mV_paged_cur, Shape>{}); + if (should_load) { + #pragma unroll + for (int k = 0; k < size<2>(tVrV); ++k) { + int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad; + if (tKpK(_0{}, k)) { + cute::copy(gmem_tiled_copy_kv_store, tVrV(_, m, k), mV_paged_cur_copy(_, ki)); + } + } + } + } + if constexpr (!KV_Same_Iter) { compute_V_ptr(); } + }; + + +}; + +} // namespace flash diff --git a/hopper/src/include/qpack.h b/hopper/src/include/qpack.h new file mode 100644 index 0000000..749a84b --- /dev/null +++ b/hopper/src/include/qpack.h @@ -0,0 +1,564 @@ +#pragma once + +#include +#include +#include "utils.h" + +namespace quant { + +using namespace cute; + +template +CUTE_DEVICE +void thread_reduce_(Tensor0 const& tensor, Tensor1& summary, Operator& op, const int num_params) { + const int pack_num = size<1>(tensor) / num_params; + + CUTE_UNROLL + for (int mi = 0; mi < size<0>(summary); ++mi) { + int col_start = (mi / 4) * pack_num; + summary(mi) = tensor(mi % 4, col_start); + + CUTE_UNROLL + for (int ni = col_start; ni < col_start + pack_num; ++ni) { + summary(mi) = op(summary(mi), tensor(mi % 4, ni)); + } + + } + +} + +template +__device__ __forceinline__ T warp_reduce(T val, Operator op) { + // Get the 
thread's position within its group of 4 + const int lane_id = threadIdx.x % 32; // Lane ID within warp + const int group_pos = lane_id % 4; // Position within group of 4 + + // Only reduce with threads that have the same position in their group of 4 + // Using butterfly pattern with xor + for (int mask = 16; mask > 0; mask >>= 1) { + T other = __shfl_xor_sync(0xffffffff, val, mask); + // Only combine if the other thread has the same group_pos + if ((lane_id ^ mask) < 32 && ((lane_id ^ mask) % 4 == group_pos)) { + val = op(val, other); + } + } + return val; +} + +template +CUTE_DEVICE +void allreduce_(Tensor0 &dst, Tensor1 &src, Tensor2 &reduce_tmp, Operator &op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); + + const int warp_id = threadIdx.x / 32; + const int lane_id = threadIdx.x % 32; + + #pragma unroll + for (int i = 0; i < size(dst); i++) { + // First do reduction within each group of 4 threads + float val = quant::warp_reduce(src(i), op); + // Write the result to shared memory for each group's leader + if (lane_id < 4) { + reduce_tmp(i,warp_id * 4 + lane_id) = val; + } + __syncthreads(); + + // First thread in the first group reads all values and reduces them + if (lane_id < 4) { + float final_val = reduce_tmp(i,0 + lane_id); + #pragma unroll + for (int w = 1; w < 4; w++) { // For 4 warps + final_val = op(final_val, reduce_tmp(i,w * 4 + lane_id)); + } + // Write back the final result + reduce_tmp(i, 0 + lane_id) = final_val; + } + __syncthreads(); + + // All threads read the final result + dst(i) = reduce_tmp(i,0 + lane_id % 4); + + } + + +} + +template +CUTE_DEVICE +void reduce_(Tensor const& tensor, Tensor& summary, Tensor2 &reduce_tmp, Operator& op, const int num_params) { + quant::thread_reduce_(tensor, summary, op, num_params); + quant::allreduce_(summary, summary, reduce_tmp, op); +} + +template +CUTE_DEVICE +void reduce_max(Tensor const& tensor, Tensor &max, Tensor2 &reduce_tmp, const int num_params) { + flash::MaxOp max_op; + quant::reduce_(tensor, max, reduce_tmp, max_op, num_params); // Use the existing reduce_q function +} + +template +CUTE_DEVICE +void reduce_min(Tensor const& tensor, Tensor &min, Tensor2 &reduce_tmp, const int num_params) { + flash::MinOp min_op; + quant::reduce_(tensor, min, reduce_tmp, min_op, num_params); // Use the existing reduce_q function +} + +template +struct qpack_kc_vt; + +template +struct qpack_kc_vt<2, Tensor1, Tensor2, Tensor3, Tensor4, Tensor5> { + static constexpr int num_bits = 2; // Add this line + CUTE_DEVICE static + void apply(Tensor1 &src, Tensor2 &dst, Tensor3 &scales_k, Tensor4 &zeros_k, Tensor5 &reduce_tmp, const int num_params) { + const float max_val = float((1 << num_bits) - 1); + const int pack_num = 4 / (num_params / 2); // TODO: check 4 + const int num_params_2 = size<1>(src) == 4 ? num_params / 2 : num_params; // TODO: change name? seems hard code? 
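+        // The per-channel asymmetric quantization below, written out (scale_inv = 1/scale is kept
+        // so the inner loop multiplies instead of divides):
+        //   scale = (max - min) / (2^num_bits - 1),  zero = min
+        //   q     = clamp(round((x - zero) / scale), 0, 2^num_bits - 1)
+        //   x_hat = q * scale + zero   (dequantization on the consumer side)
+        // For num_bits = 2, eight 2-bit codes are packed into one uint16_t, with the value at the
+        // lowest column index ending up in the least-significant bits.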
+ const int channel_stride = size<0>(src); + + // Declare per-channel tensors + using TensorChannel = decltype(make_fragment_like(scales_k(_, 0))); + TensorChannel channel_max, channel_min, channel_range, channel_scales_inv, channel_zeros; + + CUTE_UNROLL + for (int k = 0; k < size<2>(src); ++k) { + // Perform per-channel max and min reductions + quant::reduce_max(src(_, _, k), channel_max, reduce_tmp, num_params_2); + quant::reduce_min(src(_, _, k), channel_min, reduce_tmp, num_params_2); + + // Compute per-channel scale inverses and zeros + CUTE_UNROLL + for (int i = 0; i < size(channel_max); ++i) { + float max_i = float(channel_max(i)); + float min_i = float(channel_min(i)); + float range = max_i - min_i; + // Avoid division by zero + float scale_inv = (range > 0.0f) ? (max_val / range) : 0.0f; + channel_scales_inv(i) = scale_inv; + channel_zeros(i) = min_i; + // Store scales and zeros + scales_k(i, k) = scale_inv == 0 ? 0.0f : 1.0f / scale_inv; // Store actual scale + zeros_k(i, k) = min_i; + } + + // Quantize and pack the tensor + CUTE_UNROLL + for (int i = 0; i < size<0>(src); ++i) { + + CUTE_UNROLL + for (int jj = 0; jj < size<1>(src); jj += 8) { + // float val0 = float(src(i, jj, k)); + // float val1 = float(src(i, jj + 1, k)); + // float val2 = float(src(i, jj + 2, k)); + // float val3 = float(src(i, jj + 3, k)); + // float val4 = float(src(i, jj + 4, k)); + // float val5 = float(src(i, jj + 5, k)); + // float val6 = float(src(i, jj + 6, k)); + // float val7 = float(src(i, jj + 7, k)); + + // Load 4 values and convert to float + float val0 = float(src(i, jj, k)) - channel_zeros(i + (jj ) / pack_num * channel_stride); + float val1 = float(src(i, jj + 1, k)) - channel_zeros(i + (jj + 1) / pack_num * channel_stride); + float val2 = float(src(i, jj + 2, k)) - channel_zeros(i + (jj + 2) / pack_num * channel_stride); + float val3 = float(src(i, jj + 3, k)) - channel_zeros(i + (jj + 3) / pack_num * channel_stride); + float val4 = float(src(i, jj + 4, k)) - channel_zeros(i + (jj + 4) / pack_num * channel_stride); + float val5 = float(src(i, jj + 5, k)) - channel_zeros(i + (jj + 5) / pack_num * channel_stride); + float val6 = float(src(i, jj + 6, k)) - channel_zeros(i + (jj + 6) / pack_num * channel_stride); + float val7 = float(src(i, jj + 7, k)) - channel_zeros(i + (jj + 7) / pack_num * channel_stride); + + // Apply scale inverses + val0 *= channel_scales_inv(i + (jj ) / pack_num * channel_stride); + val1 *= channel_scales_inv(i + (jj + 1) / pack_num * channel_stride); + val2 *= channel_scales_inv(i + (jj + 2) / pack_num * channel_stride); + val3 *= channel_scales_inv(i + (jj + 3) / pack_num * channel_stride); + val4 *= channel_scales_inv(i + (jj + 4) / pack_num * channel_stride); + val5 *= channel_scales_inv(i + (jj + 5) / pack_num * channel_stride); + val6 *= channel_scales_inv(i + (jj + 6) / pack_num * channel_stride); + val7 *= channel_scales_inv(i + (jj + 7) / pack_num * channel_stride); + + // Round and clamp the values + val0 = fminf(fmaxf(roundf(val0), 0.0f), max_val); + val1 = fminf(fmaxf(roundf(val1), 0.0f), max_val); + val2 = fminf(fmaxf(roundf(val2), 0.0f), max_val); + val3 = fminf(fmaxf(roundf(val3), 0.0f), max_val); + val4 = fminf(fmaxf(roundf(val4), 0.0f), max_val); + val5 = fminf(fmaxf(roundf(val5), 0.0f), max_val); + val6 = fminf(fmaxf(roundf(val6), 0.0f), max_val); + val7 = fminf(fmaxf(roundf(val7), 0.0f), max_val); + + // Pack 8 values (2-bit each) into a 16-bit integer + uint16_t packed = 0; + packed |= (static_cast(static_cast(val7)) & 0x3); // 2 bits + packed <<= 2; 
+ packed |= (static_cast(static_cast(val6)) & 0x3); + packed <<= 2; + packed |= (static_cast(static_cast(val5)) & 0x3); + packed <<= 2; + packed |= (static_cast(static_cast(val4)) & 0x3); + packed <<= 2; + packed |= (static_cast(static_cast(val3)) & 0x3); + packed <<= 2; + packed |= (static_cast(static_cast(val2)) & 0x3); + packed <<= 2; + packed |= (static_cast(static_cast(val1)) & 0x3); + packed <<= 2; + packed |= (static_cast(static_cast(val0)) & 0x3); + + // Store the packed value + dst(i, jj / 8, k) = packed; + } + } + } + + } + + +}; + +template +struct qpack_kc_vt<4, Tensor1, Tensor2, Tensor3, Tensor4, Tensor5> { + static constexpr int num_bits = 4; // Add this line + CUTE_DEVICE static + void apply(Tensor1 &src, Tensor2 &dst, Tensor3 &scales_k, Tensor4 &zeros_k, Tensor5 &reduce_tmp, const int num_params) { + const float max_val = float((1 << num_bits) - 1); + const int pack_num = size<1>(src) / num_params; + const int channel_stride = size<0>(src); + + // Declare per-channel tensors + using TensorChannel = decltype(make_fragment_like(scales_k(_, 0))); + TensorChannel channel_max, channel_min, channel_range, channel_scales_inv, channel_zeros; + + + CUTE_UNROLL + for (int k = 0; k < size<2>(src); ++k) { + // Perform per-channel max and min reductions + quant::reduce_max(src(_, _, k), channel_max, reduce_tmp, num_params); + quant::reduce_min(src(_, _, k), channel_min, reduce_tmp, num_params); + + // Compute per-channel scale inverses and zeros + CUTE_UNROLL + for (int i = 0; i < size(channel_max); ++i) { + float max_i = float(channel_max(i)); + float min_i = float(channel_min(i)); + float range = max_i - min_i; + // Avoid division by zero + float scale_inv = (range > 0.0f) ? (max_val / range) : 0.0f; + channel_scales_inv(i) = scale_inv; + channel_zeros(i) = min_i; + // Store scales and zeros + scales_k(i, k) = scale_inv == 0 ? 
0.0f : 1.0f / scale_inv; // Store actual scale + zeros_k(i, k) = min_i; + } + + // Quantize and pack the tensor + CUTE_UNROLL + for (int i = 0; i < size<0>(src); ++i) { + + CUTE_UNROLL + for (int jj = 0; jj < size<1>(src); jj += 4) { + // float val0 = float(src(i, jj, k)); + // float val1 = float(src(i, jj + 1, k)); + // float val2 = float(src(i, jj + 2, k)); + // float val3 = float(src(i, jj + 3, k)); + + // Load 4 values and convert to float + float val0 = float(src(i, jj, k)) - channel_zeros(i + (jj ) / pack_num * channel_stride); + float val1 = float(src(i, jj + 1, k)) - channel_zeros(i + (jj + 1) / pack_num * channel_stride); + float val2 = float(src(i, jj + 2, k)) - channel_zeros(i + (jj + 2) / pack_num * channel_stride); + float val3 = float(src(i, jj + 3, k)) - channel_zeros(i + (jj + 3) / pack_num * channel_stride); + + // Apply scale inverses + val0 *= channel_scales_inv(i + (jj ) / pack_num * channel_stride); + val1 *= channel_scales_inv(i + (jj + 1) / pack_num * channel_stride); + val2 *= channel_scales_inv(i + (jj + 2) / pack_num * channel_stride); + val3 *= channel_scales_inv(i + (jj + 3) / pack_num * channel_stride); + + // Round and clamp the values + val0 = fminf(fmaxf(roundf(val0), 0.0f), max_val); + val1 = fminf(fmaxf(roundf(val1), 0.0f), max_val); + val2 = fminf(fmaxf(roundf(val2), 0.0f), max_val); + val3 = fminf(fmaxf(roundf(val3), 0.0f), max_val); + + // Pack the 4 quantized values into a 16-bit integer + uint16_t packed = 0; + packed |= (static_cast(static_cast(val3)) & 0xF); + packed <<= 4; + packed |= (static_cast(static_cast(val2)) & 0xF); + packed <<= 4; + packed |= (static_cast(static_cast(val1)) & 0xF); + packed <<= 4; + packed |= (static_cast(static_cast(val0)) & 0xF); + + // Store the packed value + dst(i, jj / 4, k) = packed; + } + } + } + + } +}; + +template +CUTE_DEVICE +void qpack_Kchannel_Vtensor(Tensor1 &src, Tensor2 &dst, + Tensor3 &scales_k, Tensor4 &zeros_k, Tensor5 &reduce_tmp, + const int num_params = 1) { + + qpack_kc_vt::apply(src, dst, scales_k, zeros_k, reduce_tmp, num_params); + +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_DEVICE +void quad_allreduce_g(TensorParamsG0 &dst, Tensor1 &src, Operator &op, int k, int num_params) { + CUTE_UNROLL + for (int i = k * num_params; i < (k + 1) * num_params; i++) { + + // Calculate which group of 4 this thread belongs to + const int group_id = threadIdx.x / 4; + const int group_base = group_id * 4; + + // Start with the value from the first thread in our group + auto val = __shfl_sync(uint32_t(-1), src(i), group_base); + + // Reduce with the other 3 threads in our group + #pragma unroll + for (int offset = 1; offset < 4; offset++) { + val = op(val, __shfl_sync(uint32_t(-1), src(i), group_base + offset)); + } + + // Broadcast the final result back to all threads in the group + dst(i) = val; + + } +} + +template +CUTE_DEVICE +void thread_reduce_g(Tensor0 const& tensor, TensorParamsG0& summary, Operator& op, int k, int num_params) { + CUTE_UNROLL + for (int i = k * num_params, j = 0; i < (k + 1) * num_params; i++, j++) { + int ii = size<1>(tensor) / num_params; + summary(i) = tensor(0, j * ii); + + CUTE_UNROLL + for (int mi = 0; mi < size<0>(tensor); ++mi) { + CUTE_UNROLL + for (int ni = j * ii; ni < (j + 1) * ii; ++ni) { + summary(i) = op(summary(i), tensor(mi, ni)); + } + } + } +} + +template +CUTE_DEVICE +void reduce_g(Tensor const& tensor, TensorParamsG0& summary, Operator& op, int k, int num_params) { + 
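+    // Two-phase reduction: thread_reduce_g first folds this thread's fragment into
+    // summary(k * num_params .. (k + 1) * num_params - 1) serially, then quad_allreduce_g combines
+    // the four lanes of each quad (threadIdx.x / 4) with __shfl_sync, so every lane of the quad
+    // ends up holding the group-wide result (e.g. the group max when op is flash::MaxOp).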
quant::thread_reduce_g(tensor, summary, op, k, num_params); + quant::quad_allreduce_g(summary, summary, op, k, num_params); +} + +template +CUTE_DEVICE +void reduce_max_g(Tensor const& tensor, TensorParamsG0 &max, int k, int num_params) { + flash::MaxOp max_op; + quant::reduce_g(tensor, max, max_op, k, num_params); // Use the existing reduce_q function +} + +template +CUTE_DEVICE +void reduce_min_g(Tensor const& tensor, TensorParamsG0 &min, int k, int num_params) { + flash::MinOp min_op; + quant::reduce_g(tensor, min, min_op, k, num_params); // Use the existing reduce_q function +} + +template +CUTE_DEVICE +void quant_Ktensor(Tensor1 &src, Tensor2 &dst, + TensorParamsG1 &scales_k_g, TensorParamsG2 &zeros_k_g, + const int num_params) { + + const int num_bits = 4; + + const float max_val = float((1 << num_bits) - 1); + // const int num_params = 128 / group_size; + const int ki = size<2>(src) / num_params; + + // Declare per-channel tensors + using TensorChannel = decltype(make_fragment_like(scales_k_g)); + TensorChannel channel_max, channel_min, channel_range, channel_scales_inv, channel_zeros; + + CUTE_UNROLL + for (int k = 0; k < size<1>(src); ++k) { + quant::reduce_max_g(src(_, k, _), channel_max, k, num_params); // TODO:check 128 + quant::reduce_min_g(src(_, k, _), channel_min, k, num_params); + } + + // Compute per-channel scale inverses and zeros + CUTE_UNROLL + for (int i = 0; i < size(channel_max); ++i) { + float max_i = float(channel_max(i)); + float min_i = float(channel_min(i)); + float range = max_i - min_i; + // Avoid division by zero + float scale_inv = (range > 0.0f) ? (max_val / range) : 0.0f; + channel_scales_inv(i) = scale_inv; + channel_zeros(i) = min_i; + // Store scales and zeros + scales_k_g(i) = scale_inv == 0 ? 0.0f : 1.0f / scale_inv; // Store actual scale + zeros_k_g(i) = min_i; + } + + // Pack the tensor + CUTE_UNROLL + for (int k = 0; k < size<2>(src); ++k) { + + CUTE_UNROLL + for (int i = 0; i < size<0>(src); ++i) { + + CUTE_UNROLL + for (int jj = 0; jj < size<1>(src); jj += 4) { + float zero0 = float(channel_zeros(k / ki + jj + 0 * num_params)); + float zero1 = float(channel_zeros(k / ki + jj + 1 * num_params)); + float zero2 = float(channel_zeros(k / ki + jj + 2 * num_params)); + float zero3 = float(channel_zeros(k / ki + jj + 3 * num_params)); + + float scale_inv0 = float(channel_scales_inv(k / ki + jj + 0 * num_params)); + float scale_inv1 = float(channel_scales_inv(k / ki + jj + 1 * num_params)); + float scale_inv2 = float(channel_scales_inv(k / ki + jj + 2 * num_params)); + float scale_inv3 = float(channel_scales_inv(k / ki + jj + 3 * num_params)); + + // float val0 = float(src(i, jj, k)); + // float val1 = float(src(i, jj + 1, k)); + // float val2 = float(src(i, jj + 2, k)); + // float val3 = float(src(i, jj + 3, k)); + + float val0 = float(src(i, jj, k)) - zero0; + float val1 = float(src(i, jj + 1, k)) - zero1; + float val2 = float(src(i, jj + 2, k)) - zero2; + float val3 = float(src(i, jj + 3, k)) - zero3; + + val0 *= scale_inv0; + val1 *= scale_inv1; + val2 *= scale_inv2; + val3 *= scale_inv3; + + // Round and clamp the values + val0 = fminf(fmaxf(roundf(val0), 0.0f), max_val); + val1 = fminf(fmaxf(roundf(val1), 0.0f), max_val); + val2 = fminf(fmaxf(roundf(val2), 0.0f), max_val); + val3 = fminf(fmaxf(roundf(val3), 0.0f), max_val); + + // Pack the 4 quantized values into a 16-bit integer + uint16_t packed = 0; + packed |= (static_cast(static_cast(val3)) & 0xF); + packed <<= 4; + packed |= (static_cast(static_cast(val2)) & 0xF); + packed <<= 4; + packed 
|= (static_cast(static_cast(val1)) & 0xF); + packed <<= 4; + packed |= (static_cast(static_cast(val0)) & 0xF); + + // Store the packed value + dst(i, jj / 4, k) = packed; + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_DEVICE +void pack_Ktensor_store(TiledCopyRS smem_tiled_copy, Tensor0 &src_r2s, Tensor1 &dst_r2s, + TiledCopySG gmem_tiled_copy, Tensor2 &src_s2g, Tensor3 &dst_g2s, + Tensor4 &scales, Tensor5 &zeros, Tensor6 ¶ms, + const int num_params) { + // copy from register to shared memory + cute::copy(smem_tiled_copy, src_r2s, dst_r2s); + __syncthreads(); + + // copy from shared memory to global memory + cute::copy(gmem_tiled_copy, src_s2g, dst_g2s); + __syncthreads(); + + // copy params from register to global memory + CUTE_UNROLL + for (int j = 0; j < size<0>(scales); ++j) { + params(0 + 32 * (j / num_params) + threadIdx.x / 4, j % num_params) = scales(j); + params(64 + 32 * (j / num_params) + threadIdx.x / 4, j % num_params) = zeros(j); + } + __syncthreads(); +} + +template +CUTE_DEVICE +void pack_Kchannel_store(TiledCopyRS smem_tiled_copy, Tensor0 &src_r2s, Tensor1 &dst_r2s, + TiledCopySG gmem_tiled_copy, Tensor2 &src_s2g, Tensor3 &dst_g2s, + Tensor4 &scales, Tensor5 &zeros, Tensor6 ¶ms, + const int num_params) { + // copy from register to shared memory + cute::copy(smem_tiled_copy, src_r2s, dst_r2s); + __syncthreads(); + + // copy from shared memory to global memory + cute::copy(gmem_tiled_copy, src_s2g, dst_g2s); + __syncthreads(); + + // copy params from register to global memory + CUTE_UNROLL + for (int i = 0; i < size<1>(scales); ++i) { + CUTE_UNROLL + for (int j = 0; j < size<0>(scales); ++j) { + params(j % num_params, 0 + 8 * i + 4 * (j / num_params) + threadIdx.x % 4) = scales(j, i); + params(j % num_params, 64 + 8 * i + 4 * (j / num_params) + threadIdx.x % 4) = zeros(j, i); + } + } + __syncthreads(); +} + +template +CUTE_DEVICE +void pack_Vtensor_store(TiledCopyRS smem_tiled_copy, Tensor0 &src_r2s, Tensor1 &dst_r2s, + TiledCopySG gmem_tiled_copy, Tensor2 &src_s2g, Tensor3 &dst_g2s, + Tensor4 &scales, Tensor5 &zeros, Tensor6 ¶ms, + const int num_params) { + if (kHeadDim == 128 && num_bits == 2) { + if (threadIdx.x < 64) { + cute::copy(smem_tiled_copy, src_r2s, dst_r2s); + } + } else { + cute::copy(smem_tiled_copy, src_r2s, dst_r2s); + } + __syncthreads(); + + // copy from shared memory to global memory + cute::copy(gmem_tiled_copy, src_s2g, dst_g2s); + __syncthreads(); + + // copy params from register to global memory + const int num_params_2 = num_bits == 2 ? num_params / 2 : num_params; + CUTE_UNROLL + for (int i = 0; i < size<1>(scales); ++i) { + CUTE_UNROLL + for (int j = 0; j < size<0>(scales); ++j) { + params(128 * (i / 8) + 0 + 8 * (i % 8) + 4 * (j / num_params_2) + threadIdx.x % 4, j % num_params_2) = scales(j, i); + params(128 * (i / 8) + 64 + 8 * (i % 8) + 4 * (j / num_params_2) + threadIdx.x % 4, j % num_params_2) = zeros(j, i); + } + } + __syncthreads(); +} + +} // namespace quant + + diff --git a/hopper/src/include/rotary.h b/hopper/src/include/rotary.h new file mode 100644 index 0000000..5e30456 --- /dev/null +++ b/hopper/src/include/rotary.h @@ -0,0 +1,489 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include + +#include "utils.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE void +apply_rotary_interleaved(Tensor &rK, + Tensor const &rCos, + Tensor const &rSin) { + CUTE_STATIC_ASSERT_V(rank(rK) == _1{}); + CUTE_STATIC_ASSERT_V(rank(rCos) == _1{}); + CUTE_STATIC_ASSERT_V(rank(rSin) == _1{}); + CUTE_STATIC_ASSERT_V(size<0>(rCos) == size<0>(rSin)); + static_assert(decltype(size<0>(rK))::value == decltype(size<0>(rCos))::value * 2); + static_assert(decltype(size<0>(rCos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor K_fp32 = make_tensor_like(rK); + convert_type_out(rK, K_fp32); + Tensor cos_fp32 = make_tensor_like(rCos); + convert_type_out(rCos, cos_fp32); + Tensor sin_fp32 = make_tensor_like(rSin); + convert_type_out(rSin, sin_fp32); + #pragma unroll + for (int i = 0; i < size<0>(K_fp32) / 2; ++i) { + float real = K_fp32[2 * i] * cos_fp32[i] - K_fp32[2 * i + 1] * sin_fp32[i]; + float imag = K_fp32[2 * i] * sin_fp32[i] + K_fp32[2 * i + 1] * cos_fp32[i]; + K_fp32[2 * i] = real; + K_fp32[2 * i + 1] = imag; + } + convert_type_out(K_fp32, rK); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE void +apply_rotary_contiguous(Tensor &rK_left, + Tensor &rK_right, + Tensor const &rCos, + Tensor const &rSin) { + CUTE_STATIC_ASSERT_V(rank(rK_left) == _1{}); + CUTE_STATIC_ASSERT_V(rank(rK_right) == _1{}); + CUTE_STATIC_ASSERT_V(rank(rCos) == _1{}); + CUTE_STATIC_ASSERT_V(rank(rSin) == _1{}); + CUTE_STATIC_ASSERT_V(size<0>(rK_left) == size<0>(rK_right)); + CUTE_STATIC_ASSERT_V(size<0>(rK_left) == size<0>(rCos)); + CUTE_STATIC_ASSERT_V(size<0>(rCos) == size<0>(rSin)); + static_assert(decltype(size<0>(rCos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor K_left_fp32 = make_tensor_like(rK_left); + convert_type_out(rK_left, K_left_fp32); + Tensor K_right_fp32 = make_tensor_like(rK_right); + convert_type_out(rK_right, K_right_fp32); + Tensor cos_fp32 = make_tensor_like(rCos); + convert_type_out(rCos, cos_fp32); + Tensor sin_fp32 = make_tensor_like(rSin); + convert_type_out(rSin, sin_fp32); + #pragma unroll + for (int i = 0; i < size<0>(K_left_fp32); ++i) { + float real = K_left_fp32[i] * cos_fp32[i] - K_right_fp32[i] * sin_fp32[i]; + float imag = K_left_fp32[i] * sin_fp32[i] + K_right_fp32[i] * cos_fp32[i]; + K_left_fp32[i] = real; + K_right_fp32[i] = imag; + } + convert_type_out(K_left_fp32, rK_left); + convert_type_out(K_right_fp32, rK_right); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Rotary { + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); + // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each + // thread to have 4 loads in the M direction and 2 vectorized load in the K direction. + // We want each thread to have at least 2 loads in the K direction since in the case of non-interleaved + // rotary (combining elements at indices 0 and rotary_dim/2, 1 and rotary_dim/2+1, etc), each thread will + // load twice from the same row. 
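+    // The two rotary layouts handled by the helpers above, with cos_i = cos(theta_i) and
+    // sin_i = sin(theta_i):
+    //   interleaved: pairs (x[2i], x[2i+1])            -> (x[2i]*cos_i - x[2i+1]*sin_i, x[2i]*sin_i + x[2i+1]*cos_i)
+    //   contiguous:  pairs (x[i], x[i + rotary_dim/2]) -> (x[i]*cos_i - x[i + rotary_dim/2]*sin_i, x[i]*sin_i + x[i + rotary_dim/2]*cos_i)
+    // i.e. the same complex rotation, differing only in which two elements form a pair.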
+ static constexpr int kBytePerHalfRow = kHeadDim / 2 * sizeof(Element); + static constexpr int kBlockKGmem = (kBytePerHalfRow % 128 == 0 ? 128 : (kBytePerHalfRow % 64 == 0 ? 64 : 32)) / sizeof(Element); + static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; + static_assert(NumThreads % kGmemThreadsPerRow == 0, "NumThreads must be a multiple of kGmemThreadsPerRow"); + // We assume threads loading the same row are in the same warp. + static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, "kGmemThreadsPerRow must divide NumThreadsPerWarp"); + + using LayoutAtom = Layout, Int>, + Stride, _1>>; + using TiledCopyQK = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + LayoutAtom{}, + Layout>>{})); // Val layout, 8 or 16 vals per store + using GmemTiledCopyRotary = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + LayoutAtom{}, + Layout>>{})); // Val layout, 4 or 8 vals per store + using GmemTiledCopyRotaryCont = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + LayoutAtom{}, + Layout>>{})); // Val layout, 8 or 16 vals per store + + using ShapeRotary = cute::Shape; // (seqlen_ro, rotary_dim // 2) + using StrideRotary = cute::Stride; + + using GmemThrCopyRotary = decltype(GmemTiledCopyRotary{}.get_thread_slice(int(0))); + using GmemThrCopyRotaryCont = decltype(GmemTiledCopyRotaryCont{}.get_thread_slice(int(0))); + using TensortRcR = decltype(GmemTiledCopyRotary{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape, Int>{}))); + using TensortRpR = decltype(make_tensor(make_shape(size<2>(TensortRcR{})))); + using TensortRcRCont = decltype(GmemTiledCopyRotaryCont{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape, Int>{}))); + using TensortRpRCont = decltype(make_tensor(make_shape(size<2>(TensortRcRCont{})))); + using TensormR = decltype(make_tensor( + make_gmem_ptr((Element const*)nullptr), + ShapeRotary{}, + make_stride(cute::conditional_return(_0{}, int64_t(0)), _1{}))); + using TensortRgR = decltype( + GmemTiledCopyRotary{}.get_thread_slice(int(0)).partition_S(make_tensor( + make_gmem_ptr((Element const*)nullptr), + make_shape(Int{}, Int{}, int(0)), + make_stride(cute::conditional_return(_0{}, int64_t(0)), _1{}, cute::conditional_return(_0{}, int64_t(0)))))); + using TensortRgRCont = decltype( + GmemTiledCopyRotaryCont{}.get_thread_slice(int(0)).partition_S(make_tensor( + make_gmem_ptr((Element const*)nullptr), + make_shape(Int{}, Int{}, int(0)), + make_stride(cute::conditional_return(_0{}, int64_t(0)), _1{}, cute::conditional_return(_0{}, int64_t(0)))))); + + GmemTiledCopyRotary gmem_tiled_copy_rotary; + GmemTiledCopyRotaryCont gmem_tiled_copy_rotary_cont; + bool const is_rotary_interleaved; + int const rotary_dim; + int const thread_idx; + int const max_seqlen; + GmemThrCopyRotary const gmem_thr_copy_rotary; + GmemThrCopyRotaryCont const gmem_thr_copy_rotary_cont; + TensortRpR tRpR; + TensortRpRCont tRpRCont; + TensormR mCos, mSin; + TensortRgR tRgCos, tRgSin; + TensortRgRCont tRgCosCont, tRgSinCont; + + CUTLASS_DEVICE + Rotary(Element const* const ptr_rotary_cos, ShapeRotary const &shape_rotary, StrideRotary const &stride_rotary_cos_, + Element const* const ptr_rotary_sin, StrideRotary const &stride_rotary_sin_, + bool const is_rotary_interleaved, int const thread_idx, int const max_seqlen, int const start_idx) + : is_rotary_interleaved(is_rotary_interleaved) + , rotary_dim(get<1>(shape_rotary) * 2) + , thread_idx(thread_idx) + , max_seqlen(max_seqlen) + , 
gmem_thr_copy_rotary(gmem_tiled_copy_rotary.get_thread_slice(thread_idx)) + , gmem_thr_copy_rotary_cont(gmem_tiled_copy_rotary_cont.get_thread_slice(thread_idx)) + + { + auto stride_rotary_cos = make_stride(cute::conditional_return(get<0>(stride_rotary_cos_), _0{}), get<1>(stride_rotary_cos_)); + auto stride_rotary_sin = make_stride(cute::conditional_return(get<0>(stride_rotary_sin_), _0{}), get<1>(stride_rotary_sin_)); + mCos = make_tensor(make_gmem_ptr(ptr_rotary_cos + start_idx * get<0>(stride_rotary_cos_)), shape_rotary, stride_rotary_cos); + mSin = make_tensor(make_gmem_ptr(ptr_rotary_sin + start_idx * get<0>(stride_rotary_sin_)), shape_rotary, stride_rotary_sin); + Tensor gCos = local_tile(mCos, Shape, Int>{}, make_coord(_, _0{})); // (MN, K / 2, _) + Tensor gSin = local_tile(mSin, Shape, Int>{}, make_coord(_, _0{})); // (MN, K / 2, _) + tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCos); + tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSin); + Tensor cR = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K / 2) + Tensor tRcR = gmem_thr_copy_rotary.partition_D(cR); + tRpR = make_tensor(make_shape(size<2>(tRcR))); + #pragma unroll + for (int k = 0; k < size(tRpR); ++k) { tRpR(k) = get<1>(tRcR(_0{}, _0{}, k)) < get<1>(shape_rotary); } + Tensor tRcRCont = gmem_thr_copy_rotary_cont.partition_D(cR); + tRpRCont = make_tensor(make_shape(size<2>(tRcRCont))); + #pragma unroll + for (int k = 0; k < size(tRpRCont); ++k) { tRpRCont(k) = get<1>(tRcRCont(_0{}, _0{}, k)) < get<1>(shape_rotary); } + }; + + template + CUTLASS_DEVICE + auto load_cos_sin(int const block) { + using GmemTiledCopyRo = std::conditional_t; + auto gmem_thr_copy_ro = cute::conditional_return(gmem_thr_copy_rotary, gmem_thr_copy_rotary_cont); + Tensor tRpRCur = cute::conditional_return(tRpR, tRpRCont); + Tensor tRgCosCur = cute::conditional_return(tRgCos, tRgCosCont)(_, _, _, block); + Tensor tRgSinCur = cute::conditional_return(tRgSin, tRgSinCont)(_, _, _, block); + // make_tensor_like, not make_fragment_like. If the row_stride is _0{} we want to keep it that way + Tensor tRrCos = make_tensor_like(tRgCosCur); + Tensor tRrSin = make_tensor_like(tRgSinCur); + Tensor cR = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K / 2) + Tensor tRcR = gmem_thr_copy_ro.partition_D(cR); + // If FixedPosition, only copy the first row as we only need the cos/sin for position cache_seqlens + #pragma unroll + for (int m = 0; m < (!FixedPosition ? size<1>(tRrCos) : 1); ++m) { + if (get<0>(tRcR(_0{}, m, _0{})) < std::min(max_seqlen - block * kBlockMN, kBlockMN)) { + #pragma unroll + for (int k = 0; k < size<2>(tRrCos); ++k) { + if (tRpRCur(k)) { + cute::copy(GmemTiledCopyRo{}, tRgCosCur(_, m, k), tRrCos(_, m, k)); + cute::copy(GmemTiledCopyRo{}, tRgSinCur(_, m, k), tRrSin(_, m, k)); + } + } + } + } + return cute::make_tuple(tRrCos, tRrSin);; + } + + template + CUTLASS_DEVICE + auto load_cos_sin_packgqa(int const block, cutlass::FastDivmod const &qhead_per_khead_divmod) { + static constexpr int kGmemElemsPerLoadCur = kInterleaved ? kGmemElemsPerLoad / 2 : kGmemElemsPerLoad; + using GmemTiledCopyRo = std::conditional_t; + auto gmem_thr_copy_ro = cute::conditional_return(gmem_thr_copy_rotary, gmem_thr_copy_rotary_cont); + Tensor tRpRCur = cute::conditional_return(tRpR, tRpRCont); + // make_tensor_like, not make_fragment_like. 
If the row_stride is _0{} we want to keep it that way + Tensor tRrCos = make_tensor_like(cute::conditional_return(tRgCos, tRgCosCont)(_, _, _, _0{})); + Tensor tRrSin = make_tensor_like(cute::conditional_return(tRgSin, tRgSinCont)(_, _, _, _0{})); + int const qhead_per_khead = qhead_per_khead_divmod.divisor; + Tensor cR = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K / 2) + Tensor tRcR = gmem_thr_copy_ro.partition_D(cR); + + // The main bottleneck here is actually instruction cache misses. + + // Similar to PagedKV, it's expensive to compute the pointers. + // We split the work among threads loading the same row, then __shfl_sync the pointers. + static constexpr int NumPtrPerThread = cute::ceil_div(CUTE_STATIC_V(cute::size<1>(tRrCos)), kGmemThreadsPerRow); + Tensor tPrCosPtr = make_tensor(Shape>{}); + Tensor tPrSinPtr = make_tensor(Shape>{}); + #pragma unroll + for (int i = 0; i < NumPtrPerThread; ++i) { + int const row = i * NumThreads + get<0>(tRcR(_0{}, thread_idx % kGmemThreadsPerRow, _0{})); + int const idx = block * kBlockMN + row; + int row_actual = qhead_per_khead_divmod.divide(idx); + tPrCosPtr[i] = &mCos(row_actual, _0{}); + tPrSinPtr[i] = &mSin(row_actual, _0{}); + } + + #pragma unroll + for (int m = 0; m < (!FixedPosition ? size<1>(tRgCos) : 1); ++m) { + int const idx = block * kBlockMN + get<0>(tRcR(_0{}, m, _0{})); + Element const* cos_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrCosPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); + Element const* sin_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrSinPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); + if (idx < max_seqlen * qhead_per_khead) { + Tensor mCos_copy = cute::tiled_divide(make_tensor(make_gmem_ptr(cos_ptr), Shape>{}), + Shape>{}); + Tensor mSin_copy = cute::tiled_divide(make_tensor(make_gmem_ptr(sin_ptr), Shape>{}), + Shape>{}); + #pragma unroll + for (int k = 0; k < size<2>(tRgCos); ++k) { + int const ki = get<1>(tRcR(_0{}, _0{}, k)) / (kGmemElemsPerLoadCur); + if (tRpRCur(k)) { + cute::copy(GmemTiledCopyRo{}, mCos_copy(_, ki), tRrCos(_, m, k)); + cute::copy(GmemTiledCopyRo{}, mSin_copy(_, ki), tRrSin(_, m, k)); + } + } + } + } + return cute::make_tuple(tRrCos, tRrSin); + } + + template + CUTLASS_DEVICE + void + apply_Q_interleaved(TensorsQ &sQ, // (kBlockM, kHeadDim) + TensortRrR const &tRrCos, // (kBlockM, kHeadDim / 2) split according to GmemThrCopyRotary + TensortRrR const &tRrSin, // (kBlockM, kHeadDim / 2) split according to GmemThrCopyRotary + int const m_block, int const qhead_per_khead=1) + { + TiledCopyQK tiled_copy_q; + auto gmem_thr_copy_q = tiled_copy_q.get_thread_slice(thread_idx); + Tensor tQsQ = gmem_thr_copy_q.partition_S(sQ); + Tensor tQcQ = gmem_thr_copy_q.partition_S(cute::make_identity_tensor(Shape, Int>{})); + + CUTE_STATIC_ASSERT_V(rank(tQsQ) == _3{}); + CUTE_STATIC_ASSERT_V(rank(tRrCos) == _3{}); + CUTE_STATIC_ASSERT_V(rank(tRrSin) == _3{}); + CUTE_STATIC_ASSERT_V(size<1>(tQsQ) == size<1>(tRrCos)); + CUTE_STATIC_ASSERT_V(size<2>(tQsQ) == size<2>(tRrCos)); + CUTE_STATIC_ASSERT_V(size<1>(tQsQ) == size<1>(tRrSin)); + CUTE_STATIC_ASSERT_V(size<2>(tQsQ) == size<2>(tRrSin)); + CUTE_STATIC_ASSERT_V(size<0>(tRrCos) == size<0>(tRrSin)); + static_assert(decltype(size<0>(tQsQ))::value == decltype(size<0>(tRrCos))::value * 2); + static_assert(decltype(size<0>(tRrCos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + + #pragma unroll + for (int m = 0; m < size<1>(tQsQ); 
++m) { + if (get<0>(tQcQ(_0{}, m, _0{})) < std::min(max_seqlen * qhead_per_khead - m_block * kBlockMN, kBlockMN)) { + #pragma unroll + for (int k = 0; k < size<2>(tQsQ); ++k) { + if (tRpR(k)) { + Tensor rQ = make_fragment_like(tQsQ(_, m, k)); + cute::copy(tiled_copy_q, tQsQ(_, m, k), rQ); + apply_rotary_interleaved(rQ, tRrCos(_, m, k), tRrSin(_, m, k)); + cute::copy(tiled_copy_q, rQ, tQsQ(_, m, k)); + } + } + } + } + }; + + template + CUTLASS_DEVICE + void + apply_Q_contiguous(TensorsQ &sQ, // (kBlockM, kHeadDim) + TensortRrR const &tRrCosCont, // (kBlockM, kHeadDim / 2) split according to GmemThrCopyRotaryCont + TensortRrR const &tRrSinCont, // (kBlockM, kHeadDim / 2) split according to GmemThrCopyRotaryCont + int const m_block, int const qhead_per_khead=1) + { + TiledCopyQK tiled_copy_q; + auto gmem_thr_copy_q = tiled_copy_q.get_thread_slice(thread_idx); + Tensor sQ_copy = cute::tiled_divide(sQ, Shape<_1, Int>{}); + Tensor tQcQ = gmem_thr_copy_q.partition_S(cute::make_identity_tensor(Shape, Int>{})); + + CUTE_STATIC_ASSERT_V(rank(tQcQ) == _3{}); + CUTE_STATIC_ASSERT_V(rank(tRrCosCont) == _3{}); + CUTE_STATIC_ASSERT_V(rank(tRrSinCont) == _3{}); + CUTE_STATIC_ASSERT_V(size<1>(tQcQ) == size<1>(tRrCosCont)); + CUTE_STATIC_ASSERT_V(size<2>(tQcQ) == size<2>(tRrCosCont)); + CUTE_STATIC_ASSERT_V(size<1>(tQcQ) == size<1>(tRrSinCont)); + CUTE_STATIC_ASSERT_V(size<2>(tQcQ) == size<2>(tRrSinCont)); + CUTE_STATIC_ASSERT_V(size<0>(tRrCosCont) == size<0>(tRrSinCont)); + CUTE_STATIC_ASSERT_V(size<0>(tQcQ) == size<0>(tRrCosCont)); + static_assert(decltype(size<0>(tRrCosCont))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + + #pragma unroll + for (int m = 0; m < size<1>(tQcQ); ++m) { + int const row = get<0>(tQcQ(_0{}, m, _0{})); + if (row < std::min(max_seqlen * qhead_per_khead - m_block * kBlockMN, kBlockMN)) { + #pragma unroll + for (int k = 0; k < size<2>(tQcQ); ++k) { + int const col = get<1>(tQcQ(_0{}, _0{}, k)); + if (col < rotary_dim / 2) { + int const col_idx_left = col / kGmemElemsPerLoad; + int const col_idx_right = col / kGmemElemsPerLoad + rotary_dim / (2 * kGmemElemsPerLoad); + Tensor rQ_left = make_fragment_like(sQ_copy(_, row, col_idx_left)); + cute::copy(tiled_copy_q, sQ_copy(_, row, col_idx_left), rQ_left); + Tensor rQ_right = make_fragment_like(rQ_left); + cute::copy(tiled_copy_q, sQ_copy(_, row, col_idx_right), rQ_right); + apply_rotary_contiguous(rQ_left, rQ_right, tRrCosCont(_, m, k), tRrSinCont(_, m, k)); + cute::copy(tiled_copy_q, rQ_left, sQ_copy(_, row, col_idx_left)); + cute::copy(tiled_copy_q, rQ_right, sQ_copy(_, row, col_idx_right)); + } + } + } + } + }; + + template + CUTLASS_DEVICE + void + apply_K_interleaved(TensorsK const &sK, // (kBlockN, kHeadDim) + TensorgK &gK, // (kBlockN, kHeadDim) + TensorpK const &tKpK, // (kBlockN, kHeadDim) split according to ThrCopyKV + TensortRrR const &tRrCos, // (kBlockN, kHeadDim/2) split according to GmemThrCopyRotary + TensortRrR const &tRrSin, // (kBlockN, kHeadDim/2) split according to GmemThrCopyRotary + TensorKPtr const &tPrKPtr, + int const n_block) + { + TiledCopyQK tiled_copy_k; + auto gmem_thr_copy_q = tiled_copy_k.get_thread_slice(thread_idx); + Tensor tKsK = gmem_thr_copy_q.partition_S(sK); + Tensor tKgK = gmem_thr_copy_q.partition_S(gK); + Tensor tKcK = gmem_thr_copy_q.partition_S(cute::make_identity_tensor(Shape, Int>{})); + + CUTE_STATIC_ASSERT_V(rank(tKsK) == _3{}); + CUTE_STATIC_ASSERT_V(rank(tRrCos) == _3{}); + CUTE_STATIC_ASSERT_V(rank(tRrSin) == _3{}); + CUTE_STATIC_ASSERT_V(size<1>(tKsK) == 
size<1>(tRrCos)); + CUTE_STATIC_ASSERT_V(size<2>(tKsK) == size<2>(tRrCos)); + CUTE_STATIC_ASSERT_V(size<1>(tKsK) == size<1>(tRrSin)); + CUTE_STATIC_ASSERT_V(size<2>(tKsK) == size<2>(tRrSin)); + CUTE_STATIC_ASSERT_V(size<0>(tRrCos) == size<0>(tRrSin)); + static_assert(decltype(size<0>(tKsK))::value == decltype(size<0>(tRrCos))::value * 2); + static_assert(decltype(size<0>(tRrCos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + if constexpr (PagedKV) { + static_assert(decltype(size(tPrKPtr))::value == cute::ceil_div(size<1>(tKcK), kGmemThreadsPerRow)); + } + + #pragma unroll + for (int m = 0; m < size<1>(tKsK); ++m) { + int const row = get<0>(tKcK(_0{}, m, _0{})); + auto mK_cur_copy = [&] { + if constexpr (PagedKV) { + Element* k_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow)); + Tensor mK_cur = make_tensor(make_gmem_ptr(k_ptr), Shape>{}); + return cute::tiled_divide(mK_cur, Shape>{}); + } else { + return nullptr; + } + }(); + if (row < std::min(max_seqlen - n_block * kBlockMN, kBlockMN)) { + #pragma unroll + for (int k = 0; k < size<2>(tKsK); ++k) { + if (tKpK(k)) { + Tensor rK = make_fragment_like(tKsK(_, m, k)); + cute::copy(tiled_copy_k, tKsK(_, m, k), rK); + if (tRpR(k)) { apply_rotary_interleaved(rK, tRrCos(_, m, k), tRrSin(_, m, k)); } + if constexpr (!PagedKV) { + cute::copy(tiled_copy_k, rK, tKgK(_, m, k)); + } else { + int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad; + cute::copy(tiled_copy_k, rK, mK_cur_copy(_, ki)); + } + } + } + } + } + }; + + template + CUTLASS_DEVICE + void + apply_K_contiguous(TensorsK const &sK, // (kBlockN, kHeadDim) + TensorgK &gK, // (kBlockN, kHeadDim) + TensorpK const &tKpK, // (kBlockN, kHeadDim) split according to ThrCopyKV + TensortRrR const &tRrCosCont, // (kBlockN, kHeadDim/2) split according to GmemThrCopyRotaryCont + TensortRrR const &tRrSinCont, // (kBlockN, kHeadDim/2) split according to GmemThrCopyRotaryCont + TensorKPtr const &tPrKPtr, + int const n_block, int const max_k) + { + TiledCopyQK tiled_copy_k; + auto gmem_thr_copy_q = tiled_copy_k.get_thread_slice(thread_idx); + Tensor sK_copy = cute::tiled_divide(sK, Shape<_1, Int>{}); + Tensor gK_copy = cute::tiled_divide(gK, Shape<_1, Int>{}); + Tensor tKcK = gmem_thr_copy_q.partition_S(cute::make_identity_tensor(Shape, Int>{})); + + CUTE_STATIC_ASSERT_V(rank(tKcK) == _3{}); + CUTE_STATIC_ASSERT_V(rank(tRrCosCont) == _3{}); + CUTE_STATIC_ASSERT_V(rank(tRrSinCont) == _3{}); + CUTE_STATIC_ASSERT_V(size<1>(tKcK) == size<1>(tRrCosCont)); + CUTE_STATIC_ASSERT_V(size<2>(tKcK) == size<2>(tRrCosCont)); + CUTE_STATIC_ASSERT_V(size<1>(tKcK) == size<1>(tRrSinCont)); + CUTE_STATIC_ASSERT_V(size<2>(tKcK) == size<2>(tRrSinCont)); + CUTE_STATIC_ASSERT_V(size<0>(tRrCosCont) == size<0>(tRrSinCont)); + CUTE_STATIC_ASSERT_V(size<0>(tKcK) == size<0>(tRrCosCont)); + static_assert(decltype(size<0>(tRrCosCont))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + if constexpr (PagedKV) { + static_assert(decltype(size(tPrKPtr))::value == cute::ceil_div(size<1>(tKcK), kGmemThreadsPerRow)); + } + + const int ro_dim_vec = rotary_dim / kGmemElemsPerLoad; + const int non_ro_dim_vec = (max_k - rotary_dim) / kGmemElemsPerLoad; + #pragma unroll + for (int m = 0; m < size<1>(tKcK); ++m) { + int const row = get<0>(tKcK(_0{}, m, _0{})); + Tensor gK_cur_copy = [&] { + if constexpr (PagedKV) { + Element* k_ptr = reinterpret_cast(__shfl_sync(0xffffffff, 
reinterpret_cast(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow)); + Tensor mK_cur = make_tensor(make_gmem_ptr(k_ptr), Shape>{}); + return cute::tiled_divide(mK_cur, Shape>{}); + } else { + return gK_copy(_, row, _); + } + }(); + if (row < std::min(max_seqlen - n_block * kBlockMN, kBlockMN)) { + #pragma unroll + for (int k = 0; k < size<2>(tKcK); ++k) { + if (tKpK(k)) { + int const col = get<1>(tKcK(_0{}, _0{}, k)); + bool rotate = col < rotary_dim / 2; + int const col_idx_left = rotate ? col / kGmemElemsPerLoad : (col + rotary_dim / 2) / kGmemElemsPerLoad; + int const col_idx_right = col_idx_left + (rotate ? ro_dim_vec / 2 : non_ro_dim_vec / 2); + Tensor rK_left = make_fragment_like(sK_copy(_, row, col_idx_left)); + cute::copy(tiled_copy_k, sK_copy(_, row, col_idx_left), rK_left); + Tensor rK_right = make_fragment_like(rK_left); + cute::copy(tiled_copy_k, sK_copy(_, row, col_idx_right), rK_right); + if (rotate) { + apply_rotary_contiguous(rK_left, rK_right, tRrCosCont(_, m, k), tRrSinCont(_, m, k)); + } + cute::copy(tiled_copy_k, rK_left, gK_cur_copy(_, col_idx_left)); + if (col_idx_right * kGmemElemsPerLoad < max_k) { + cute::copy(tiled_copy_k, rK_right, gK_cur_copy(_, col_idx_right)); + } + } + } + } + } + }; + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/hopper/src/include/seqlen.h b/hopper/src/include/seqlen.h new file mode 100644 index 0000000..21a7471 --- /dev/null +++ b/hopper/src/include/seqlen.h @@ -0,0 +1,93 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +namespace flash { + +// We consolidate all the info related to sequence length here. This is so that we can do all +// the gmem reads once at the beginning of each tile, rather than having to repeat these reads +// to compute various things like n_block_min, n_block_max, etc. + +template +struct SeqlenInfo { + + int const offset, offset_padded; + int const seqlen; + + CUTLASS_DEVICE + SeqlenInfo(int const bidb, int const seqlen_static, int const* const cu_seqlens, int const* const seqused) + : offset(!Varlen || cu_seqlens == nullptr ? 0 : cu_seqlens[bidb]) + , offset_padded(!Varlen || cu_seqlens == nullptr ? 0 : (cu_seqlens[bidb] + bidb * kBlock) / kBlock * kBlock) + , seqlen(!Varlen + ? seqlen_static + : (seqused ? seqused[bidb] : (cu_seqlens ? cu_seqlens[bidb + 1] - cu_seqlens[bidb] : seqlen_static))) + { + } + +}; + +template +struct SeqlenInfoQK { + + int const offset_q, offset_k, offset_q_padded; + int const seqlen_q, seqlen_k; + + CUTLASS_DEVICE + SeqlenInfoQK(int const bidb, int const seqlen_q_static, int const seqlen_k_static, + int const* const cu_seqlens_q, int const* const cu_seqlens_k, + int const* const seqused_q, int const* const seqused_k + ) + : offset_q(!Varlen || cu_seqlens_q == nullptr ? 0 : cu_seqlens_q[bidb]) + , offset_k(!Varlen || cu_seqlens_k == nullptr ? 0 : cu_seqlens_k[bidb]) + // If varlen, the layout for dPSum, LSE_log2, and dQaccum is that we pad each sequence in the batch + // by an extra kBlockM, so that the write for each sequence doesn't touch the next sequence. + // Sequence i starts at cu_seqlens[i] + i * kBlockM and ends at cu_seqlens[i + 1] + i * kBlockM + // However, the start must align to multiples of kBlockM. 
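+      // Worked example (hypothetical values): with kBlockM = 128 and cu_seqlens_q = {0, 100, 300},
+      // sequence 1 nominally starts at 100 + 1 * 128 = 228, which the integer division below rounds
+      // down to offset_q_padded = 228 / 128 * 128 = 128; sequence 0 starts at 0 and its block-aligned
+      // writes cover at most ceil(100 / 128) * 128 = 128 entries, so the two padded regions do not overlap.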
+ , offset_q_padded(!Varlen || cu_seqlens_q == nullptr ? 0 : (cu_seqlens_q[bidb] + bidb * kBlockM) / kBlockM * kBlockM) + , seqlen_q(!Varlen + ? seqlen_q_static + : (seqused_q ? seqused_q[bidb] : (cu_seqlens_q ? cu_seqlens_q[bidb + 1] - cu_seqlens_q[bidb] : seqlen_q_static))) + , seqlen_k(!Varlen + ? seqlen_k_static + : (seqused_k ? seqused_k[bidb] : (cu_seqlens_k ? cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb] : seqlen_k_static))) + { + } + +}; + +template +struct SeqlenInfoQKNewK { + + static_assert(!(AppendKV && !Varlen), "AppendKV is only supported with Varlen"); + + int const leftpad_k; + int const offset_q, offset_k, offset_k_new; + int const seqlen_q, seqlen_k_og, seqlen_k_new, seqlen_k; + + CUTLASS_DEVICE + SeqlenInfoQKNewK(int const bidb, int const seqlen_q_static, int const seqlen_k_static, int const shape_K_new_0, + int const* const cu_seqlens_q, int const* const cu_seqlens_k, int const* const cu_seqlens_k_new, + int const* const seqused_q, int const* const seqused_k, int const* const ptr_leftpad_k + ) + : leftpad_k(ptr_leftpad_k ? ptr_leftpad_k[bidb] : 0) + , offset_q(!Varlen || cu_seqlens_q == nullptr ? 0 : cu_seqlens_q[bidb]) + , offset_k(!Varlen ? 0 : (cu_seqlens_k ? cu_seqlens_k[bidb] : 0) + leftpad_k) + , offset_k_new(!AppendKV || cu_seqlens_k_new == nullptr ? 0 : cu_seqlens_k_new[bidb]) + , seqlen_q(!Varlen + ? seqlen_q_static + : (seqused_q ? seqused_q[bidb] : (cu_seqlens_q ? cu_seqlens_q[bidb + 1] - cu_seqlens_q[bidb] : seqlen_q_static))) + , seqlen_k_og(!Varlen + ? seqlen_k_static + : (seqused_k ? seqused_k[bidb] : (cu_seqlens_k ? cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb] : seqlen_k_static)) - leftpad_k) + , seqlen_k_new(!AppendKV + ? 0 + : (cu_seqlens_k_new ? cu_seqlens_k_new[bidb + 1] - cu_seqlens_k_new[bidb] : shape_K_new_0)) + , seqlen_k(!AppendKV ? seqlen_k_og : seqlen_k_og + seqlen_k_new) + { + } + +}; + +} // namespace flash diff --git a/hopper/src/include/sm90_pipeline_no_cluster.hpp b/hopper/src/include/sm90_pipeline_no_cluster.hpp new file mode 100644 index 0000000..65a3d15 --- /dev/null +++ b/hopper/src/include/sm90_pipeline_no_cluster.hpp @@ -0,0 +1,99 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +namespace cutlass { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// As of Cutlass v3.6.0, if size(ClusterShape) == 1, PipelineTmaAsync has all threads +// signaling the barrier during consumer_release. This causes a perf regression in FA3 +// forward pass (especially hdim 128 causal). We instead reimplement the version of +// PipelineTmaAsync before v3.6.0 where only 1 out of 128 threads signals the barrier. 
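+// Concretely (as implemented below): init_barriers sets the EMPTY barrier arrival count to
+// params.num_consumers / NumThreadsPerWarpGroup, and consumer_release marks only threads with
+// threadIdx.x % NumThreadsPerWarpGroup == 0 as signaling threads, so one thread per consumer
+// warpgroup arrives at the barrier instead of all of them.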
+// +// Assumption: params.num_consumers % NumThreadsPerWarpGroup == 0 +template > +class PipelineTmaAsyncNoCluster: public Base { +public: + using FullBarrier = typename Base::FullBarrier; + using EmptyBarrier = typename Base::EmptyBarrier; + static constexpr uint32_t Stages = Stages_; + using PipelineState = typename Base::PipelineState; + + using SharedStorage = typename Base::SharedStorage; + using ThreadCategory = typename Base::ThreadCategory; + using Params = typename Base::Params; + + static + CUTLASS_DEVICE + void + init_barriers(SharedStorage& storage, Params params) { + int warp_idx = canonical_warp_idx_sync(); + bool is_initializing_warp = (warp_idx == 0); + if (is_initializing_warp) { + // Barrier FULL and EMPTY init + constexpr int producer_arv_cnt = 1; + uint32_t const num_consumer_warpgroups_per_cluster = params.num_consumers / NumThreadsPerWarpGroup; + uint32_t const multicast_consumer_arrival_count = num_consumer_warpgroups_per_cluster; + + cutlass::arch::detail::initialize_barrier_array_pair_aligned( + storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); + } + cutlass::arch::fence_barrier_init(); + } + + template + CUTLASS_DEVICE + PipelineTmaAsyncNoCluster(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}, InitMasks = {}) + : Base(storage, params, make_shape(_1{}, _1{}, _1{}) /*cluster_shape*/, cute::false_type{} /*init_barriers*/, cute::false_type{} /*init_masks*/) + , empty_barrier_ptr_(&storage.empty_barrier_[0]) { + + int warp_idx = canonical_warp_idx_sync(); + int lane_predicate = cute::elect_one_sync(); + + static_assert(cute::is_same_v || cute::is_same_v); + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_barriers(storage, params); + } + + } + + // Constructor + template + CUTLASS_DEVICE + PipelineTmaAsyncNoCluster(SharedStorage& storage, Params params, ClusterShape cluster_shape) + : PipelineTmaAsyncNoCluster(storage, params, cluster_shape, cute::true_type{}, cute::true_type{}) { } + + template + CUTLASS_DEVICE + PipelineTmaAsyncNoCluster(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}) + : PipelineTmaAsyncNoCluster(storage, params, cluster_shape, InitBarriers{}, cute::true_type{}) { } + + CUTLASS_DEVICE + void consumer_release(PipelineState state) { + consumer_release(state.index()); + } + +private: + EmptyBarrier* const empty_barrier_ptr_ = nullptr; + + // Consumer signalling Producer of completion + // Ensures all blocks in the Same Row and Column get notifed. + CUTLASS_DEVICE + void consumer_release(uint32_t stage, uint32_t skip = false) { + empty_barrier_ptr_[stage].arrive(0 /*dst_blockid_*/, uint32_t(threadIdx.x % cutlass::NumThreadsPerWarpGroup == 0) & (!skip) /*is_signaling_thread*/); + } + +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cutlass diff --git a/hopper/src/include/softmax.h b/hopper/src/include/softmax.h new file mode 100644 index 0000000..8fcdb6b --- /dev/null +++ b/hopper/src/include/softmax.h @@ -0,0 +1,170 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include + +#include + +#include + +#include "utils.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ __forceinline__ void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ni++) { + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = zero_init && ni == 0 ? tensor(mi, ni) : op(summary(mi), tensor(mi, ni)); + } + } +} + +template +__device__ __forceinline__ void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); + #pragma unroll + for (int i = 0; i < size(dst); i++) { + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +__device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { + thread_reduce_(tensor, summary, op); + quad_allreduce_(summary, summary, op); +} + +template +__device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor &max){ + MaxOp max_op; + reduce_(tensor, max, max_op); +} + +template +__device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor &sum){ + SumOp sum_op; + thread_reduce_(tensor, sum, sum_op); + if constexpr (warp_reduce) { quad_allreduce_(sum, sum, sum_op); } +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ void scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { + // For FP8, we can subtract max by 8.0 so that the value after exp2 is in the range of [0, 256]. + // This lets us use more of the FP8 range (instead of just [0, 1]) to reduce underflow. + static constexpr float max_offset = float(Max_offset); // We can only template on int, not float + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + const float max_scaled = Check_inf + ? (max(mi) == -INFINITY ? 0.f : (!Scale_max ? max(mi) : max(mi) * scale) - max_offset) + : (!Scale_max ? max(mi) : max(mi) * scale) - max_offset; + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)). This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. 
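+            // Concretely (assuming `scale` carries the log2(e) factor, as the caller's member name
+            // softmax_scale_log2 suggests), with max_scaled == max(mi) * scale - max_offset as above:
+            //   exp2f(x * scale - max_scaled) == exp(softmax_scale * (x - max(mi))) * 2^max_offset,
+            // i.e. one FFMA feeding one EX2 per element instead of a separate subtract, multiply and exp.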
+ tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax { + + using TensorT = decltype(make_tensor(Shape>{})); + TensorT row_max, row_sum; + float const softmax_scale_log2; + + CUTLASS_DEVICE Softmax(float const softmax_scale_log2_) : softmax_scale_log2(softmax_scale_log2_) {}; + + template + __forceinline__ __device__ TensorT max_get_scale(Tensor0 &acc_s) { + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + static_assert(CUTE_STATIC_V(size<0>(scores)) == kNRows); + TensorT scores_scale; + if constexpr (Is_first) { + flash::template reduce_max(scores, row_max); + cute::fill(scores_scale, 1.f); + } else { + Tensor scores_max_prev = make_fragment_like(row_max); + cute::copy(row_max, scores_max_prev); + flash::template reduce_max(scores, row_max); + #pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float scores_max_cur = !Check_inf + ? row_max(mi) + : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); + scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + row_sum(mi) *= scores_scale(mi); + } + } + return scores_scale; + }; + + template + __forceinline__ __device__ void online_softmax(Tensor0 &acc_s) { + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + static_assert(CUTE_STATIC_V(size<0>(scores)) == kNRows); + flash::template scale_apply_exp2(scores, row_max, softmax_scale_log2); + // We don't do the reduce across threads here since we don't need to use the row_sum. + // We do that reduce at the end when we need to normalize the softmax. + flash::reduce_sum(scores, row_sum); + }; + + __forceinline__ __device__ TensorT finalize(float const final_scale=1.f) { + SumOp sum_op; + quad_allreduce_(row_sum, row_sum, sum_op); + TensorT scores_scale; + #pragma unroll + for (int mi = 0; mi < size(row_sum); ++mi) { + float sum = row_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 0.f : 1.f / sum; + scores_scale(mi) = inv_sum * final_scale; + // For FP8, we might have scaled the output of exp by 2**8 so we need to divide sum by that amount. + if constexpr (Max_offset != 0) { + static constexpr float sum_scale = 1.f / float(1 << Max_offset); + sum *= sum_scale; + } + row_sum(mi) = (sum == 0.f || sum != sum) ? 
-INFINITY : row_max(mi) * (softmax_scale_log2 * float(M_LN2)) + __logf(sum); + } + return scores_scale; + }; + + template + __forceinline__ __device__ void rescale_o(Tensor1 &acc_o, TensorT const &scores_scale) { + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(CUTE_STATIC_V(size<0>(acc_o_rowcol)) == kNRows); + #pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale(mi); } + } + }; + +}; + +} // namespace flash diff --git a/hopper/src/include/static_switch.h b/hopper/src/include/static_switch.h new file mode 100644 index 0000000..5e13b5f --- /dev/null +++ b/hopper/src/include/static_switch.h @@ -0,0 +1,181 @@ +// Inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +// + +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +#ifdef FLASHATTENTION_DISABLE_LOCAL + #define CAUSAL_LOCAL_SWITCH(CAUSAL_COND, LOCAL_COND, CAUSAL_CONST_NAME, LOCAL_CONST_NAME, ...) \ + [&] { \ + constexpr static bool LOCAL_CONST_NAME = false; \ + if (CAUSAL_COND) { \ + constexpr static bool CAUSAL_CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CAUSAL_CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() +#else + #define CAUSAL_LOCAL_SWITCH(CAUSAL_COND, LOCAL_COND, CAUSAL_CONST_NAME, LOCAL_CONST_NAME, ...) \ + [&] { \ + if (CAUSAL_COND) { \ + constexpr static bool CAUSAL_CONST_NAME = true; \ + constexpr static bool LOCAL_CONST_NAME = false; \ + return __VA_ARGS__(); \ + } else if (LOCAL_COND) { \ + constexpr static bool CAUSAL_CONST_NAME = false; \ + constexpr static bool LOCAL_CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CAUSAL_CONST_NAME = false; \ + constexpr static bool LOCAL_CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() +#endif + +#ifdef FLASHATTENTION_DISABLE_SOFTCAP + #define SOFTCAP_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define SOFTCAP_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_PAGEDKV + #define PAGEDKV_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define PAGEDKV_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_SPLIT + #define SPLIT_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define SPLIT_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_APPENDKV + #define APPENDKV_SWITCH(COND, CONST_NAME, ...) 
\ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define APPENDKV_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_PACKGQA + #define PACKGQA_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define PACKGQA_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_VARLEN + #define VARLEN_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define VARLEN_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_CLUSTER + #define CLUSTER_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define CLUSTER_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_SM8x + #define ARCH_SWITCH(ARCH, ARCH_NAME, ...) \ + [&] { \ + constexpr static int ARCH_NAME = 90; \ + return __VA_ARGS__(); \ + }() +#else + #define ARCH_SWITCH(ARCH, ARCH_NAME, ...) \ + [&] { \ + if (ARCH == 86 || ARCH == 89) { \ + constexpr static int ARCH_NAME = 86; \ + return __VA_ARGS__(); \ + } else if (ARCH < 90) { \ + constexpr static int ARCH_NAME = 80; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static int ARCH_NAME = 90; \ + return __VA_ARGS__(); \ + } \ + }() +#endif + +#ifndef FLASHATTENTION_ENABLE_VCOLMAJOR + #define VCOLMAJOR_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define VCOLMAJOR_SWITCH BOOL_SWITCH +#endif + +#define HEADDIM_SWITCH(HEADDIM, ...) \ + [&] { \ + if (HEADDIM == 64) { \ + constexpr static int kHeadSize = 64; \ + return __VA_ARGS__(); \ + } else if (HEADDIM == 96) { \ + constexpr static int kHeadSize = 96; \ + return __VA_ARGS__(); \ + } else if (HEADDIM == 128) { \ + constexpr static int kHeadSize = 128; \ + return __VA_ARGS__(); \ + } else if (HEADDIM == 192) { \ + constexpr static int kHeadSize = 192; \ + return __VA_ARGS__(); \ + } else if (HEADDIM == 256) { \ + constexpr static int kHeadSize = 256; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/hopper/src/include/tile_scheduler.hpp b/hopper/src/include/tile_scheduler.hpp new file mode 100644 index 0000000..0b74d0e --- /dev/null +++ b/hopper/src/include/tile_scheduler.hpp @@ -0,0 +1,537 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+ ******************************************************************************/ + +#pragma once + +#include "cutlass/fast_math.h" +#include "cutlass/arch/barrier.h" + +#include "named_barrier.hpp" + +namespace flash { + +/////////////////////////////////////////////////////////////////////////////// + +// Host side kernel arguments +struct TileSchedulerArguments { + // num_head is num_head_q if not PackGQA, else num_head_k + int const num_blocks, num_head, num_batch, num_splits; + int const qhead_per_khead; + int const seqlen; // Only used if Varlen and cu_seqlens == nullptr and seqused == nullptr + int const seqlen_k, headdim, element_size; // Used to calculate L2 swizzling + int* const tile_count_semaphore = nullptr; + int* const cu_seqlens = nullptr; + int* const seqused = nullptr; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template +class SingleTileScheduler { + +public: + + using SharedStorage = int; + + // Device side kernel params + struct Params { + int const num_blocks, num_head, num_batch, num_splits; + int const qhead_per_khead; + int const seqlen; + cutlass::FastDivmod nsplits_divmod; + int* const cu_seqlens; + int* const seqused; + }; + + static Params + to_underlying_arguments(TileSchedulerArguments const& args) { + return {args.num_blocks, args.num_head, args.num_batch, !Split ? 1 : args.num_splits, + args.qhead_per_khead, args.seqlen, + cutlass::FastDivmod(!Split ? 1 : args.num_splits), + !Varlen ? nullptr : args.cu_seqlens, !Varlen ? nullptr : args.seqused}; + } + + static dim3 + get_grid_shape(Params const& params, int num_sm) { + return {uint32_t(params.num_blocks), uint32_t((!Split ? 1 : params.num_splits) * params.num_head), uint32_t(params.num_batch)}; + } + + struct WorkTileInfo { + int block_idx = 0; + int bidh = 0; + int bidb = 0; + bool is_valid_tile = false; + + CUTLASS_DEVICE + bool + is_valid(Params const& params) const { + return is_valid_tile; + } + + CUTLASS_DEVICE + cute::tuple + get_block_coord(Params const& params) const { + if constexpr (!Split) { + return {block_idx, bidh, bidb, 0 /*split_idx*/}; + } else { + int split_idx; + int bidh_actual = params.nsplits_divmod.divmod(split_idx, bidh); + return {block_idx, bidh_actual, bidb, split_idx}; + } + } + + }; + + CUTLASS_DEVICE + SingleTileScheduler(SharedStorage* const smem_scheduler) { } + + template + CUTLASS_DEVICE + WorkTileInfo + get_initial_work(Params const& params) const { + WorkTileInfo work_info {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), true}; + if constexpr (Varlen) { + int seqlen = params.seqused + ? params.seqused[work_info.bidb] + : (params.cu_seqlens ? 
params.cu_seqlens[work_info.bidb + 1] - params.cu_seqlens[work_info.bidb] : params.seqlen); + if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; } + work_info.is_valid_tile = work_info.block_idx * kBlock < seqlen; + } + return work_info; + } + + CUTLASS_DEVICE + void + init_consumer() const {} + + CUTLASS_DEVICE + void + prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {} + + template + CUTLASS_DEVICE + WorkTileInfo + get_next_work(Params const& params, WorkTileInfo const& current_work) const { + return {-1, -1, -1, false}; + } + +}; + +/////////////////////////////////////////////////////////////////////////////// + +template +class StaticPersistentTileScheduler { + +public: + + using SharedStorage = int; + + // Device side kernel params + struct Params { + int total_blocks; + cutlass::FastDivmod m_block_divmod, head_divmod; + cutlass::FastDivmod nsplits_divmod; + }; + + static Params + to_underlying_arguments(TileSchedulerArguments const& args) { + return {args.num_blocks * args.num_head * args.num_batch * (!Split ? 1 : args.num_splits), + cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head * (!Split ? 1 : args.num_splits)), + cutlass::FastDivmod(!Split ? 1 : args.num_splits)}; + } + + static dim3 + get_grid_shape(Params const& params, int num_sm) { + return {uint32_t(num_sm)}; + } + + struct WorkTileInfo { + int tile_idx; + + CUTLASS_DEVICE + bool + is_valid(Params const& params) const { + return tile_idx < params.total_blocks; + } + + CUTLASS_DEVICE + cute::tuple + get_block_coord(Params const& params) const { + int block, bidh, bidb; + bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(block, tile_idx)); + int split_idx = 0; + if constexpr (Split) { + bidh = params.nsplits_divmod.divmod(split_idx, bidh); + } + return {block, bidh, bidb, split_idx}; + } + + }; + + CUTLASS_DEVICE + StaticPersistentTileScheduler(SharedStorage* const smem_scheduler) {}; + + template + CUTLASS_DEVICE + WorkTileInfo + get_initial_work(Params const& params) const { + return {int(blockIdx.x)}; + } + + CUTLASS_DEVICE + void + init_consumer() const {} + + CUTLASS_DEVICE + void + prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {} + + template + CUTLASS_DEVICE + WorkTileInfo + get_next_work(Params const& params, WorkTileInfo const& current_work) const { + return {current_work.tile_idx + int(gridDim.x)}; + } + +}; + +template +class DynamicPersistentTileScheduler { + + // This scheduler targets the causal (or local) case where each tile takes different + // amount of time. We use longest-processing-time-first scheduling: + // the longest remaining tile is assigned to the first SM that's free. + // SM indicates they are free by incrementing a semaphore. + // However, we have to make sure K & V still fit into L2 cache, so we perform scheduling + // on "sections" of the head & batch dimension, each section consisting of e.g. 8 heads. + // This is the L2 swizzling part. The size of each section is precomputed based on the + // size of K & V and the L2 cache size. + + static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads); + static constexpr int NumThreads = WarpSpecialized ? 
NumMmaThreads + NumProducerThreads : NumMmaThreads; + +public: + using SharedStorage = int; + +protected: + SharedStorage* const tile_count_smem; + +public: + + // Device side kernel params + struct Params { + int const total_blocks; + cutlass::FastDivmod const m_block_divmod, head_divmod; + cutlass::FastDivmod const l2_minor_divmod, l2_major_divmod; + cutlass::FastDivmod const l2_minor_residual_divmod; + int const num_hb_quotient; + int* const tile_count_semaphore; + }; + + static Params + to_underlying_arguments(TileSchedulerArguments const& args) { + int const size_one_kv_head = args.seqlen_k * args.headdim * args.element_size * 2; + int const size_l2 = 32 * 1024 * 1024; // 32 MB for K & V + // Swizzle is the size of each "section". Round swizzle to a power of 2 + // If not PackGQA already, the size of each section can increase by qhead_per_khead + int const swizzle = (1 << cutlass::find_log2(size_l2 / size_one_kv_head)) * (PackGQA ? 1 : args.qhead_per_khead); + // If we're in the last section (called residual), we don't want to divide by + // swizzle. Instead we want to divide by the remainder. + int const num_hb_remainder = (args.num_head * args.num_batch) % swizzle; + int const num_split_blocks = args.num_blocks * (!Split ? 1 : args.num_splits); + // printf("num_split_blocks = %d, num_head = %d, num_batch = %d, swizzle = %d, PackGQA = %d, qhead_per_khead = %d, num_hb_remainder = %d\n", num_split_blocks, args.num_head, args.num_batch, swizzle, int(PackGQA), args.qhead_per_khead, num_hb_remainder); + assert(args.tile_count_semaphore != nullptr); + return {num_split_blocks * args.num_head * args.num_batch, + cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head), + cutlass::FastDivmod(swizzle), cutlass::FastDivmod(swizzle * num_split_blocks), + // don't divide by 0 + cutlass::FastDivmod(num_hb_remainder > 0 ? num_hb_remainder : 1), + (args.num_head * args.num_batch) / swizzle, + args.tile_count_semaphore}; + } + + static dim3 + get_grid_shape(Params const& params, int num_sm) { + return {uint32_t(num_sm)}; + } + + struct WorkTileInfo { + int tile_idx; + + CUTLASS_DEVICE + bool + is_valid(Params const& params) const { + return tile_idx < params.total_blocks; + } + + CUTLASS_DEVICE + cute::tuple + get_block_coord(Params const& params) const { + int block, bidh, bidb; + int l2_mod, bidhb, bidhb_residual; + bidhb = params.l2_major_divmod.divmod(l2_mod, tile_idx); + // If we're in the last section (called residual), we don't want to divide by + // swizzle. Instead we want to divide by the remainder. 
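+            // Example (hypothetical sizes): with num_head * num_batch = 20 and swizzle = 8,
+            // num_hb_quotient == 2 full sections of 8 (head, batch) pairs plus a residual of
+            // 20 % 8 = 4, so tiles falling in the residual section are decomposed with
+            // l2_minor_residual_divmod (divisor 4) rather than l2_minor_divmod (divisor 8).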
+ if (bidhb < params.num_hb_quotient) { + block = params.l2_minor_divmod.divmod(bidhb_residual, l2_mod); + } else { + block = params.l2_minor_residual_divmod.divmod(bidhb_residual, l2_mod); + } + bidb = params.head_divmod.divmod(bidh, bidhb * params.l2_minor_divmod.divisor + bidhb_residual); + int split_idx = 0; + if constexpr (Split) { + split_idx = params.m_block_divmod.divmod(block, block); + } + // Longest-processing-time-first + block = params.m_block_divmod.divisor - 1 - block; + return {block, bidh, bidb, split_idx}; + } + + }; + + CUTLASS_DEVICE + DynamicPersistentTileScheduler(SharedStorage* const smem_scheduler) : tile_count_smem(smem_scheduler) {}; + + template + CUTLASS_DEVICE + WorkTileInfo + get_initial_work(Params const& params) const { + return {int(blockIdx.x)}; + } + + CUTLASS_DEVICE + void + init_consumer() const { + if (WarpSpecialized || cutlass::canonical_warp_idx_sync() > 0) { + flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + } + } + + CUTLASS_DEVICE + void + prefetch_next_work(Params const& params, WorkTileInfo& current_work) const { + if (threadIdx.x % NumProducerThreads == 0) { + current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x); + } + } + + template + CUTLASS_DEVICE + WorkTileInfo + get_next_work(Params const& params, WorkTileInfo const& current_work) const { + if constexpr (IsProducerWarp) { + // thread 0 already has the right tile_idx, just need to broadcast to the rest of warp 0 + int new_tile_idx = __shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/); + flash::named_barrier_sync(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + if (threadIdx.x % NumProducerThreads == 0) { + *tile_count_smem = current_work.tile_idx; + } + flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + return {new_tile_idx}; + } else { + flash::named_barrier_sync(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + int tile_idx = *tile_count_smem; + flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + return {tile_idx}; + } + } + +}; + + +template +class VarlenDynamicPersistentTileScheduler { + + static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads); + static constexpr int NumThreads = WarpSpecialized ? NumMmaThreads + NumProducerThreads : NumMmaThreads; + +public: + using SharedStorage = int4; + +protected: + SharedStorage* const work_info_smem; + +public: + + // Device side kernel params + struct Params { + int num_head, num_batch; + int const qhead_per_khead; + int const seqlen; + cutlass::FastDivmod nsplits_divmod; + int* const tile_count_semaphore; + int* const cu_seqlens; + int* const seqused; + }; + + static Params + to_underlying_arguments(TileSchedulerArguments const& args) { + // If Split, for the purpose of scheduling, we pretend that instead there are + // (args.num_splits * args.num_head) number of heads. + assert(args.tile_count_semaphore != nullptr); + return {args.num_head * (!Split ? 1 : args.num_splits), args.num_batch, + args.qhead_per_khead, args.seqlen, + cutlass::FastDivmod(!Split ? 
1 : args.num_splits), + args.tile_count_semaphore, args.cu_seqlens, args.seqused}; + } + + static dim3 + get_grid_shape(Params const& params, int num_sm) { + return {uint32_t(num_sm)}; + } + + struct WorkTileInfo { + int tile_idx, block, bidh, bidb; + + CUTLASS_DEVICE + bool + is_valid(Params const& params) const { + // if (blockIdx.x >= 0 && (threadIdx.x == 128 || threadIdx.x == 0)) { printf("blockIdx.x = %d, threadIdx.x = %d, checking valid, bidb = %d, params.num_batch = %d\n", blockIdx.x, threadIdx.x, bidb, params.num_batch); } + return bidb < params.num_batch; + } + + CUTLASS_DEVICE + cute::tuple + get_block_coord(Params const& params) const { + if constexpr (!Split) { + return {block, bidh, bidb, 0 /*split_idx*/}; + } else { + int split_idx; + int bidh_actual = params.nsplits_divmod.divmod(split_idx, bidh); + return {block, bidh_actual, bidb, split_idx}; + } + } + }; + + CUTLASS_DEVICE + VarlenDynamicPersistentTileScheduler(SharedStorage* const smem_scheduler) : work_info_smem(smem_scheduler) {}; + + CUTLASS_DEVICE + WorkTileInfo + tile_idx_to_work_tile(Params const& params, int next_tile_idx, WorkTileInfo const& current_work) const { + auto prefix_sum = [](int val) { + int lane = threadIdx.x % cutlass::NumThreadsPerWarp; + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < cutlass::NumThreadsPerWarp; i <<= 1) { + int32_t partial_sum = __shfl_up_sync(0xffffffff, val, i); + if (lane >= i) { val += partial_sum; } + } + return val; + }; + + auto get_num_m_blocks = [&](int bidb_start) { + int lane = threadIdx.x % cutlass::NumThreadsPerWarp; + int seqlen; + if (params.seqused) { + seqlen = lane + bidb_start < params.num_batch ? params.seqused[lane + bidb_start] : 0; + } else if (params.cu_seqlens) { + int cur_cu_seqlen = lane + bidb_start <= params.num_batch ? params.cu_seqlens[lane + bidb_start] : 0; + int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); + seqlen = next_cu_seqlen - cur_cu_seqlen; + } else { + seqlen = params.seqlen; + } + if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; } + return lane + bidb_start < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 + ? 
cute::ceil_div(seqlen, kBlock) : 0; + }; + + int num_m_blocks = get_num_m_blocks(current_work.bidb); // Different for each lane + // Cumulative number of blocks for the next 31 batches + int num_m_blocks_cumulative = prefix_sum(num_m_blocks); + // Total number of blocks for the next 31 batches + int m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1); + int group_end_tile = current_work.tile_idx - current_work.block - current_work.bidh * __shfl_sync(0xffffffff, num_m_blocks, 0 /*lane*/) + m_blocks_in_group * params.num_head; // Same for all lanes + int bidb = current_work.bidb; + // if (blockIdx.x <= 9 && threadIdx.x == 0) { + // printf("Before while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group); + // } + while (group_end_tile <= next_tile_idx) { + bidb += cutlass::NumThreadsPerWarp - 1; + if (bidb >= params.num_batch) { + // if (blockIdx.x <= 9 && threadIdx.x == 0) { + // printf("Returning early, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group); + // } + return {next_tile_idx, 0, 0, params.num_batch}; + } + num_m_blocks = get_num_m_blocks(bidb); + num_m_blocks_cumulative = prefix_sum(num_m_blocks); + m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1); + group_end_tile += m_blocks_in_group * params.num_head; + // if (blockIdx.x <= 9 && threadIdx.x == 0) { + // printf("Bottom of while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group); + // } + } + int group_start_tile = group_end_tile - m_blocks_in_group * params.num_head; + // The next problem to process is the first one that does not have ending tile position + // that is greater than or equal to tile index. + int batch_idx_in_group = __popc(__ballot_sync(0xffffffff, group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx)); + bidb += batch_idx_in_group; + num_m_blocks = __shfl_sync(0xffffffff, num_m_blocks, batch_idx_in_group); + int mh_block = next_tile_idx - group_start_tile - (batch_idx_in_group == 0 ? 
0 : __shfl_sync(0xffffffff, num_m_blocks_cumulative, batch_idx_in_group - 1)) * params.num_head; + int bidh = mh_block / num_m_blocks; + int block = mh_block - bidh * num_m_blocks; + // if (blockIdx.x <= 9 && threadIdx.x == 0) { + // printf("blockIdx.x = %d, threadIdx.x = %d, batch_idx_in_group = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d, mh_block = %d, bidh = %d, block = %d\n", blockIdx.x, threadIdx.x, batch_idx_in_group, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group, mh_block, bidh, block); + // } + return {next_tile_idx, block, bidh, bidb}; + } + + template + CUTLASS_DEVICE + WorkTileInfo + get_initial_work(Params const& params) const { + if constexpr (IsProducerWarp) { + WorkTileInfo work_info = tile_idx_to_work_tile(params, int(blockIdx.x), {0, 0, 0, 0}); + if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) { + *work_info_smem = make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb); + } + flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + return work_info; + } else { + return get_next_work(params, {0, 0, 0, 0}); + } + } + + CUTLASS_DEVICE + void + init_consumer() const { + // Don't arrive at the TileCountSmemEmpty barrier here, because get_initial_work will do that + } + + CUTLASS_DEVICE + void + prefetch_next_work(Params const& params, WorkTileInfo& current_work) const { + if (threadIdx.x % NumProducerThreads == 0) { + current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x); + } + } + + template + CUTLASS_DEVICE + WorkTileInfo + get_next_work(Params const& params, WorkTileInfo const& current_work) const { + if constexpr (IsProducerWarp) { + // thread 0 has the next tile_idx, just need to broadcast to the rest of warp 0 + int new_tile_idx = __shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/); + WorkTileInfo work_info = {__shfl_sync(0xffffffff, current_work.tile_idx, 1 /*lane*/), current_work.block, current_work.bidh, current_work.bidb}; + work_info = tile_idx_to_work_tile(params, new_tile_idx, work_info); + flash::named_barrier_sync(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) { + *work_info_smem = make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb); + } + flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + return work_info; + } else { + flash::named_barrier_sync(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + int4 work_info = *work_info_smem; + flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + return WorkTileInfo{work_info.x, work_info.y, work_info.z, work_info.w}; + } + } + +}; + +} // flash diff --git a/hopper/src/include/tile_size.h b/hopper/src/include/tile_size.h new file mode 100644 index 0000000..5ff5dc3 --- /dev/null +++ b/hopper/src/include/tile_size.h @@ -0,0 +1,69 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include + +// Return {kBlockM, kBlockN, Mma1_is_RS, IntraWGOverlap} +constexpr std::tuple tile_size_fwd_sm90( + int headdim, bool is_causal, bool is_local, int element_size=2, + bool v_colmajor=false, bool paged_kv=false, bool softcap=false) { + if (element_size == 2) { + if (headdim <= 64) { + return {192, 128, true, true}; + // Good for long seqlen (>= 4k) but suffers from tile quantization at short seqlen + // return {192, is_causal || is_local ? 192 : 176, true, false}; + } else if (headdim <= 96) { + return {192, is_local || paged_kv ? 128 : 144, false, true}; + } else if (headdim <= 128) { + // return {128, is_causal || is_local || paged_kv ? 128 : 176, true, true}; + return {64, 128, true, false}; + + // {128, 192, false, false} and {192, 128, false, true} are quite good too + // 128 x 192 hits the limit of smem if Mma1_is_RS, 128 x 144 hits the limit if !Mma1_is_RS + } else if (headdim <= 192) { + return {128, paged_kv || is_local ? 96 : 112, true, true}; // 128 x 112 hits the limit of smem + } else { + return {128, is_local ? 64 : 80, true, true}; // 128 x 80 hits the limit of smem + } + } else { + if (headdim <= 64) { + return {192, 160, true, true}; + } else if (headdim <= 96) { + return {192, 128, true, true}; + } else if (headdim <= 128) { + return {128, paged_kv ? 160 : (v_colmajor || (softcap && is_local) ? 192 : 224), true, true}; + } else if (headdim <= 192) { + return {128, (paged_kv || softcap) && is_local ? 128 : 160, true, true}; + } else { + return {128, is_local ? 64 : 128, true, !paged_kv}; // PagedKV uses more registers so we disabled IntraWGOverlap + } + } +} + +// Return {kBlockM, kBlockN, kNWarps, kStages, Q_in_regs} +constexpr std::tuple tile_size_fwd_sm8x( + bool sm86_or_89, int headdim, bool is_causal, bool is_local, int element_size=2, + bool paged_kv=false, bool varlen_and_split=false, + bool softcap=false, bool append_kv=false) { + if (element_size == 2) { + if (headdim <= 64) { + return {128, varlen_and_split ? 80 : (is_local ? 96 : 112), 4, 1, false}; + } else if (headdim <= 96) { + return {128, varlen_and_split || is_local ? 48 : 64, 4, 1, false}; + } else if (headdim <= 128) { + bool const use_8_warps = sm86_or_89 | varlen_and_split; + return {128, use_8_warps ? (varlen_and_split ? (is_local ? 96 : 112) : (is_local ? 96 : 128)) : (is_local ? 48 : 64), use_8_warps ? 8 : 4, 1, use_8_warps}; + } else if (headdim <= 192) { + bool const kBlockN_64 = append_kv || is_local || varlen_and_split || paged_kv; + return {128, kBlockN_64 ? 64 : 96, 8, sm86_or_89 ? 1 : 2, !kBlockN_64}; + } else { + return {128, sm86_or_89 ? (append_kv ? 32 : (varlen_and_split || is_local ? 48 : 64)) : (append_kv ? 48 : (varlen_and_split || is_local ? 64 : 96)), 8, 1, sm86_or_89 && !append_kv}; + } + } else { + // Placeholder for now + return {128, 64, 8, 2, false}; + } +} diff --git a/hopper/src/include/utils.h b/hopper/src/include/utils.h new file mode 100644 index 0000000..3b727e6 --- /dev/null +++ b/hopper/src/include/utils.h @@ -0,0 +1,599 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include +#include +#include + +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include +#endif + +#include + +#include +#include +#include +#include + +#define DEBUG 0 +#define DEBUG_LOAD 0 +#define DEBUG2 0 + +#define PRINT(name, content) \ + print(name); \ + print(" : "); \ + print(content); \ + print("\n"); + +#define PRINTTENSOR(name, content) \ + print(name); \ + print(" : "); \ + print_tensor(content); \ + print("\n"); + + +#define CHECK_CUDA(call) \ + do { \ + cudaError_t status_ = call; \ + if (status_ != cudaSuccess) { \ + fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ + exit(1); \ + } \ + } while(0) + +#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError()) + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// A wrapper for the kernel that is used to guard against compilation on +// architectures that will never use the kernel. The purpose of this is to +// reduce the size of the compiled binary. +// Adapted from https://github.com/vllm-project/vllm/blob/4d29e91be84d27ca313d657eee92c067439a4c23/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh#L55 +template +struct enable_sm90_or_later : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + Kernel::operator()(std::forward(args)...); +#endif + } +}; + +template +struct enable_sm80_to_sm89 : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ <= 890) + Kernel::operator()(std::forward(args)...); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MaxOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; } +}; + +template <> +struct MaxOp { +// This is slightly faster +__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); } +}; + +template +struct MinOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x < y ? 
x : y; } +}; + +template <> +struct MinOp { +// This is slightly faster +__device__ __forceinline__ float operator()(float const &x, float const &y) { return min(x, y); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ __forceinline__ T run(T x, Operator &op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Allreduce<2> { +template +static __device__ __forceinline__ T run(T x, Operator &op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; +} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) +// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) +template +CUTLASS_DEVICE auto convert_layout_acc_rowcol(Layout0 acc_layout) { + if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 + static_assert(decltype(size<0, 0>(acc_layout))::value == 2); + static_assert(decltype(size<0, 1>(acc_layout))::value == 2); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = acc_layout; + if constexpr (!Transposed) { + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l))); + } else { + return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l))); + } + + } else { // SM80 + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) + if constexpr (!Transposed) { + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); + } else { + return make_layout(make_layout(get<0, 0>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l))); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) +// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8. 
+// For SM90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N)) +// For SM90, FP8, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((4, 2, 2), MMA_M, (N / 32, MMA_N)) +template +CUTLASS_DEVICE auto convert_layout_acc_Aregs(Layout0 acc_layout) { + using X = Underscore; + if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 + static_assert(decltype(size<0, 0>(acc_layout))::value == 2); + static_assert(decltype(size<0, 1>(acc_layout))::value == 2); + static_assert(decltype(rank(acc_layout))::value == 3); + static_assert(decltype(rank(get<0>(acc_layout)))::value == 3); + if constexpr (sizeof(typename MMA_Traits::ValTypeA) == 2) { + auto l = logical_divide(get<0, 2>(acc_layout), Tile<_2>{}); // ((2, N / 16)) + return make_layout(make_layout(get<0, 0>(acc_layout), get<0, 1>(acc_layout), get<0, 0>(l)), get<1>(acc_layout), coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); + } else { + static_assert(sizeof(typename MMA_Traits::ValTypeA) == 1); + static_assert(decltype(stride<0, 0>(acc_layout))::value == 1); + static_assert(decltype(stride<0, 1>(acc_layout))::value == 2); + auto l = logical_divide(get<0, 2>(acc_layout), Tile>>{}); // (((2, 2), N / 32)) + // This combines the first two modes (<0, 0> and <0, 1>) into one mode. + // Will require register shuffling later to be correct. + return make_layout(make_layout(Layout<_4>{}, get<0, 0, 0>(l), get<0, 0, 1>(l)), + get<1>(acc_layout), + coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); // ((4, 2, 2), MMA_M, N / 32 * MMA_N) + // This combination is right but doesn't work with register shuffling. + // return make_layout(make_layout(coalesce(make_layout(get<0, 0>(acc_layout), get<0, 0, 0>(l))), get<0, 1>(acc_layout), get<0, 0, 1>(l)), + // get<1>(acc_layout), + // coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); + } + } else { // SM80 + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + constexpr int mma_shape_K = get<2>(typename MMA_Traits::Shape_MNK{}); + static_assert(mma_shape_K == 8 || mma_shape_K == 16); + if constexpr (mma_shape_K == 8) { + return acc_layout; + } else { + auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) + return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE auto convert_type_unsafe(Tensor const &tensor) { + using From_type = typename Engine::value_type; + static constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast *>(tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); + // Unsafe because we're returning a tensor with memory allocated on the stack. If the compiler does not + // inline this function, then the memory might not be valid. +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE void convert_type_out(Tensor const &tensor, Tensor &out) { + // Somehow if we allocate out inside this function and return it, e2e is slower and the output can be wrong. 
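+    // Typical call pattern (a sketch with hypothetical names, not taken from this file):
+    //   Tensor tOrO = make_tensor_like<ElementO>(acc_o);   // caller-owned fp16/bf16 fragment
+    //   flash::convert_type_out(acc_o, tOrO);               // downcast the fp32 accumulator into it
+    // i.e. the destination registers are allocated by the caller rather than returned from here.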
+ using From_type = typename Engine::value_type; + using To_type = typename EngineOut::value_type; + static constexpr int FragmentSize = std::max(sizeof(From_type) / sizeof(To_type), sizeof(To_type) / sizeof(From_type)); + static_assert(CUTE_STATIC_V(size(tensor)) % FragmentSize == 0, "Fragment size does not vectorize properly"); + Tensor frag = recast const>(tensor); + Tensor out_frg = recast>(out); + static_assert(size(frag) == size(out_frg)); + cutlass::NumericArrayConverter convert_op; + #pragma unroll + for (int i = 0; i < size(frag); ++i) { out_frg[i] = convert_op(frag[i]); } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Blocks until all but N previous cp.async.commit_group operations have committed. +// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all +// (which is equivalent to commit_group then wait_group 0). +// Instead we just call cp.async.wait_group 0, which is slightly faster. +// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113 +template +CUTE_HOST_DEVICE +void cp_async_wait() { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE +auto mma_partition_fragment_AB(Mma const& mma, Tensor0 const& tensor0) { + if constexpr (A) { + return mma.partition_fragment_A(tensor0); + } else { + return mma.partition_fragment_B(tensor0); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE void gemm(TiledMma& tiled_mma, Tensor0 const& tCrA, Tensor1 const& tCrB, Tensor2& tCrC) { + if constexpr (M_slice >= 0) { + static constexpr int MMA_M = decltype(size<1>(tCrC))::value; + static_assert(M_slice < MMA_M); + // After logical_divide, C has shape ((2,2,V), (MMA_M, 1), MMA_N) + Tensor tCrC_slice = cute::logical_divide(tCrC, Shape>{})(_, make_coord(Int{}, _), _); + if constexpr (!SwapAB) { + Tensor tCrA_slice = cute::logical_divide(tCrA, Shape>{})(_, make_coord(Int{}, _), _); + gemm(tiled_mma, tCrA_slice, tCrB, tCrC_slice); + } else { + Tensor tCrB_slice = cute::logical_divide(tCrB, Shape>{})(_, make_coord(Int{}, _), _); + gemm(tiled_mma, tCrA, tCrB_slice, tCrC_slice); + } + } else { + constexpr bool Is_RS = !cute::is_base_of::value; + // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const + if constexpr (Is_RS) { + if constexpr (!SwapAB) { + warpgroup_fence_operand(const_cast(tCrA)); + } else { + warpgroup_fence_operand(const_cast(tCrB)); + } + } + warpgroup_fence_operand(tCrC); + warpgroup_arrive(); + if constexpr (zero_init) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + if constexpr (!SwapAB) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + } else { + cute::gemm(tiled_mma, tCrB(_,_,k_block), tCrA(_,_,k_block), tCrC); + } + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + if constexpr (wg_wait >= 0) { warpgroup_wait(); } + warpgroup_fence_operand(tCrC); + if constexpr (Is_RS) { + if constexpr (!SwapAB) { + warpgroup_fence_operand(const_cast(tCrA)); + } else { + warpgroup_fence_operand(const_cast(tCrB)); + } + } + } +} + 
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool A_in_regs=false, bool B_in_regs=false, bool SwapAB=false,
+          typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3, typename Tensor4,
+          typename TiledMma, typename TiledCopyA, typename TiledCopyB,
+          typename ThrCopyA, typename ThrCopyB, typename Hook>
+CUTLASS_DEVICE void gemm_sm80(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
+                              Tensor4 const& tCsB, TiledMma tiled_mma,
+                              TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
+                              ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B, Hook fn) {
+    if constexpr (SwapAB) {
+        gemm_sm80<B_in_regs, A_in_regs>(acc, tCrB, tCrA, tCsB, tCsA, tiled_mma, smem_tiled_copy_B, smem_tiled_copy_A, smem_thr_copy_B, smem_thr_copy_A, fn);
+    } else {
+        CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));                // MMA_M
+        CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));                // MMA_N
+        CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));               // MMA_K
+        Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
+        CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));     // M
+        Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
+        CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));     // N
+        if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
+        if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
+        #pragma unroll
+        for (int i = 0; i < size<2>(tCrA); ++i) {
+            if (i < size<2>(tCrA) - 1) {
+                if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
+                if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
+            }
+            if constexpr (!std::is_same_v<Hook, std::nullptr_t>) {
+                if (i == 0) { fn(); }
+            }
+            cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
+        }
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
+          typename TiledMma, typename TiledCopy, typename ThrCopy>
+CUTLASS_DEVICE void gemm_rs_sm80(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
+                                 TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
+                                 ThrCopy smem_thr_copy_B) {
+    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));                // MMA_M
+    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));                // MMA_N
+    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));               // MMA_K
+    Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
+    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));     // N
+    cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
+    #pragma unroll
+    for (int i = 0; i < size<2>(tCrA); ++i) {
+        if (i < size<2>(tCrA) - 1) {
+            cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
+        }
+        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
+          typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
+          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
+CUTLASS_DEVICE void copy(TiledCopy const &tiled_copy, Tensor<Engine0, Layout0> const &S,
+                         Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
+                         Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
+    // Decay TiledCopy to CopyAtom
+    auto copy_atom = static_cast<typename TiledCopy::Atom const&>(tiled_copy);
+    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
+    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
+    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K
+    // There's no case where !Clear_OOB_K && Clear_OOB_MN
+    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
+    auto has_with_bool = cute::is_valid([](auto t) -> void_t<decltype(declval<decltype(t)>().with(true))> {}, copy_atom);
+    #pragma unroll
+    for (int m = 0; m < size<1>(S); ++m) {
+        bool predicate_mn = Is_even_MN || get<0>(identity_MN(_0{}, m, _0{})) < max_MN;
+        if constexpr (Is_even_MN || !Clear_OOB_MN) {
+            if (Is_even_MN || predicate_mn) {
+                #pragma unroll
+                for (int k = 0; k < size<2>(S); ++k) {
+                    if constexpr (Is_even_K || !Clear_OOB_K) {
+                        if (Is_even_K || predicate_K(k)) { cute::copy(copy_atom, S(_, m, k), D(_, m, k)); }
+                    } else {  // Clear_OOB_K == true && Is_even_K == false
+                        // If copy traits can be transformed with a predicate value, do it, otherwise branch here
+                        if constexpr (has_with_bool) {
+                            cute::copy(copy_atom.with(predicate_K(k)), S(_, m, k), D(_, m, k));
+                        } else {
+                            if (predicate_K(k)) {
+                                cute::copy(copy_atom, S(_, m, k), D(_, m, k));
+                            } else {
+                                cute::clear(D(_, m, k));
+                            }
+                        }
+                    }
+                }
+            }
+        } else {  // Clear_OOB_MN == true && Is_even_MN == false, also implies Clear_OOB_K == true
+            if constexpr (!has_with_bool) {
+                if (predicate_mn) {
+                    #pragma unroll
+                    for (int k = 0; k < size<2>(S); ++k) {
+                        if (Is_even_K || predicate_K(k)) {
+                            cute::copy(copy_atom, S(_, m, k), D(_, m, k));
+                        } else if (Clear_OOB_K) {
+                            cute::clear(D(_, m, k));
+                        }
+                    }
+                } else {
+                    cute::clear(D(_, m, _));
+                }
+            } else {  // combine the mn predicate with the k predicate
+                #pragma unroll
+                for (int k = 0; k < size<2>(S); ++k) {
+                    cute::copy(copy_atom.with(predicate_mn && (Is_even_K || predicate_K(k))), S(_, m, k), D(_, m, k));
+                }
+            }
+        }
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Byte permute and shuffle to match register layout of
+// (FP8 downcasted) accumulator of GEMM-I to FP8 operand A of GEMM-II.
+template <typename Fragment>
+CUTLASS_DEVICE void permute_Aregs_fp8(Fragment &frag) {
+    // frag has shape ((4, 2, 2), MMA_M, MMA_N), each element is 8 bits
+    static_assert(decltype(size<0, 0>(frag))::value == 4);
+    static_assert(decltype(size<0, 1>(frag))::value == 2);
+    static_assert(decltype(stride<0, 0>(frag))::value == 1);
+    static_assert(decltype(stride<0, 1>(frag))::value == 4);
+    static_assert(sizeof(typename Fragment::value_type) == 1);
+
+    int quad_idx = threadIdx.x % 4;
+    bool lane_03 = quad_idx == 0 || quad_idx == 3;
+    int selector_upper = lane_03 ? 0x5410 : 0x1054;
+    int selector_lower = lane_03 ? 0x7632 : 0x3276;
+
+    static constexpr int upper_map[4] = {0, 3, 1, 2};
+    // static constexpr int lower_map[4] = {1, 2, 0, 3};
+
+    Tensor frag_64b = recast<uint2>(frag);  // ((1, 1, 2), MMA_M, MMA_N)
+    #pragma unroll
+    for (int i = 0; i < size(frag_64b); ++i) {
+        uint32_t upper = frag_64b[i].x;
+        uint32_t lower = frag_64b[i].y;
+        uint32_t upper0 = lane_03 ? upper : lower;
+        uint32_t lower0 = lane_03 ? lower : upper;
+        upper0 = __shfl_sync(uint32_t(-1), upper0, upper_map[quad_idx], 4);
+        // lower0 = __shfl_sync(uint32_t(-1), lower0, lower_map[quad_idx], 4);
+        lower0 = __shfl_sync(uint32_t(-1), lower0, upper_map[quad_idx] ^ 1, 4);
+        frag_64b[i].x = __byte_perm(upper0, lower0, selector_upper);
+        frag_64b[i].y = __byte_perm(upper0, lower0, selector_lower);
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Fragment>
+CUTLASS_DEVICE void permute_Cregs_fp8(Fragment &frag) {
+    // frag has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 32 bits
+    static_assert(decltype(size<0, 0>(frag))::value == 2);
+    static_assert(decltype(size<0, 1>(frag))::value == 2);
+    static_assert(decltype(size<0, 2>(frag))::value % 2 == 0);
+    static_assert(decltype(stride<0, 0>(frag))::value == 1);
+    static_assert(sizeof(typename Fragment::value_type) == 4);
+    Tensor frag_64b = group_modes<1, 3>(recast<uint2>(frag));  // ((1, 2, N / 8), (MMA_M, MMA_N))
+    #pragma unroll
+    for (int mi = 0; mi < size<1>(frag_64b); ++mi) {
+        #pragma unroll
+        for (int i = 0; i < size<0, 2>(frag_64b) / 2; ++i) {
+            cutlass::swap(frag_64b(make_coord(_0{}, _1{}, 2 * i), mi), frag_64b(make_coord(_0{}, _0{}, 2 * i + 1), mi));
+        }
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Fragment>
+CUTLASS_DEVICE void permute_output_fp8(Fragment &out) {
+    // out has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 32 bits
+    static_assert(decltype(size<0, 0>(out))::value == 2);
+    static_assert(decltype(size<0, 1>(out))::value == 2);
+    static_assert(decltype(size<0, 2>(out))::value % 2 == 0);
+    static_assert(decltype(stride<0, 0>(out))::value == 1);
+    static_assert(sizeof(typename Fragment::value_type) == 4);
+    Tensor frag = group_modes<1, 3>(out);  // ((2, 2, N / 8), (MMA_M, MMA_N))
+    #pragma unroll
+    for (int mi = 0; mi < size<1>(frag); ++mi) {
+        #pragma unroll
+        for (int j = 0; j < size<0, 1>(frag); ++j) {
+            #pragma unroll
+            for (int i = 0; i < size<0, 2>(frag) / 2; ++i) {
+                cutlass::swap(frag(make_coord(_1{}, j, 2 * i), mi), frag(make_coord(_0{}, j, 2 * i + 1), mi));
+            }
+        }
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
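+// Variant of permute_output_fp8 for the FP8 path where, as the name suggests, V is consumed in
+// column-major (transposed) form: values are exchanged between thread quads with __shfl_sync
+// rather than swapped within a single thread's registers.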
+template <typename Fragment>
+CUTLASS_DEVICE void permute_output_fp8_Vcolmajor(Fragment &frag) {
+    // frag has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 16 bits
+    static_assert(decltype(size<0, 0>(frag))::value == 2);
+    static_assert(decltype(size<0, 1>(frag))::value == 2);
+    static_assert(decltype(stride<0, 0>(frag))::value == 1);
+    static_assert(sizeof(typename Fragment::value_type) == 2 || sizeof(typename Fragment::value_type) == 4);
+
+    int quad_idx = threadIdx.x % 4;
+    bool lane_03 = quad_idx == 0 || quad_idx == 3;
+
+    static constexpr int upper_map[4] = {0, 2, 3, 1};
+    // static constexpr int lower_map[4] = {2, 0, 1, 3};
+
+    // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(frag); }
+    using type2 = std::conditional_t<sizeof(typename Fragment::value_type) == 2, uint32_t, uint64_t>;
+    Tensor frag_2 = group_modes<1, 3>(recast<type2>(frag));  // ((1, 2, N / 8), (MMA_M, MMA_N))
+    // if (blockIdx.x == 0 && threadIdx.x == 128) { print(frag); printf("\n"); print(frag_2); }
+    #pragma unroll
+    for (int mi = 0; mi < size<1>(frag_2); ++mi) {
+        #pragma unroll
+        for (int j = 0; j < size<0, 1>(frag_2); ++j) {
+            #pragma unroll
+            for (int i = 0; i < size<0, 2>(frag_2) / 2; ++i) {
+                type2 upper = frag_2(make_coord(_0{}, j, 2 * i), mi);
+                type2 lower = frag_2(make_coord(_0{}, j, 2 * i + 1), mi);
+                type2 upper0 = lane_03 ? upper : lower;
+                type2 lower0 = lane_03 ? lower : upper;
+                upper0 = __shfl_sync(uint32_t(-1), upper0, upper_map[quad_idx], 4);
+                // lower0 = __shfl_sync(uint32_t(-1), lower0, lower_map[quad_idx], 4);
+                lower0 = __shfl_sync(uint32_t(-1), lower0, upper_map[quad_idx] ^ 2, 4);
+                frag_2(make_coord(_0{}, j, 2 * i), mi) = lane_03 ? upper0 : lower0;
+                frag_2(make_coord(_0{}, j, 2 * i + 1), mi) = lane_03 ? lower0 : upper0;
+            }
+        }
+    }
+    // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(frag); }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Engine, typename Layout>
+CUTLASS_DEVICE void apply_softcap(Tensor<Engine, Layout> &tensor, float const softcap) {
+    #pragma unroll
+    for (int i = 0; i < size(tensor); ++i) {
+        tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
+    }
+}
+
+template <typename Engine, typename Layout>
+CUTLASS_DEVICE auto calculate_dtanh(Tensor<Engine, Layout> &tensor) {
+    Tensor out = make_fragment_like(tensor);
+    #pragma unroll
+    for (int i = 0; i < size(tensor); ++i) {
+        out(i) = 1.f - (tensor(i) * tensor(i));
+    }
+    return out;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+CUTLASS_DEVICE
+int canonical_warp_group_idx_nosync() {
+    return threadIdx.x / cutlass::NumThreadsPerWarpGroup;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace flash
diff --git a/hopper/src/test_single_decode.cu b/hopper/src/test_single_decode.cu
new file mode 100644
index 0000000..35cfb4e
--- /dev/null
+++ b/hopper/src/test_single_decode.cu
@@ -0,0 +1,114 @@
+#include <torch/torch.h>
+#include <cmath>
+
+#include "flash_api.h"
+
+torch::Tensor single_mha(torch::Tensor& q, torch::Tensor& k, torch::Tensor& v, int head_dim) {
+    const float sm_scale = 1.f / std::sqrt(float(head_dim));
+    auto scaled_q = q * sm_scale;
+
+    auto scores = torch::einsum("bthd,bshd->bhts", {scaled_q, k});
+    auto attention = torch::softmax(scores, -1).to(v.dtype());
+    auto output = torch::einsum("bhts,bshd->bthd", {attention, v});
+    return output;
+}
+
+template <int num_heads, int num_heads_kv, int head_dim, int num_bits>
+void TestDecodingKernelCorrectness(int seqlen_kv, const std::string& quant_mode, const int group_size) {
+    torch::manual_seed(42);
+
+    const int bs = 1;
+    const int seqlen_q = 1;
+    const int pack_nums = 16 / num_bits;
+
+    torch::Tensor Q_host = torch::rand({bs, seqlen_q, num_heads, head_dim}, torch::dtype(torch::kHalf));
+    torch::Tensor K_host = torch::randn({bs, seqlen_kv, num_heads_kv, head_dim}, torch::dtype(torch::kHalf));
+    torch::Tensor V_host = torch::randn({bs, seqlen_kv, num_heads_kv, head_dim}, torch::dtype(torch::kHalf));
+
+    torch::Tensor Q_device = Q_host.to(torch::kCUDA);
+    torch::Tensor K_device = K_host.to(torch::kCUDA);
+    torch::Tensor V_device = V_host.to(torch::kCUDA);
+
+    at::Tensor k_pack, k_params, v_pack, v_params;
+    if (quant_mode == "k-channel") {
+        k_pack = torch::empty({bs, seqlen_kv / pack_nums, num_heads_kv, head_dim}, torch::dtype(torch::kUInt16)).to(torch::kCUDA);
+        k_params = torch::empty({bs, seqlen_kv / group_size, num_heads_kv, head_dim}, torch::dtype(torch::kFloat32)).to(torch::kCUDA);
+    } else {
+        k_pack = torch::empty({bs, seqlen_kv, num_heads_kv, head_dim / pack_nums}, torch::dtype(torch::kUInt16)).to(torch::kCUDA);
+        k_params = torch::empty({bs, head_dim / group_size, num_heads_kv, seqlen_kv}, torch::dtype(torch::kFloat32)).to(torch::kCUDA);
+    }
+    v_pack = torch::empty({bs, seqlen_kv, num_heads_kv, head_dim / pack_nums}, torch::dtype(torch::kUInt16)).to(torch::kCUDA);
+    v_params = torch::empty({bs, head_dim / group_size, num_heads_kv, seqlen_kv}, torch::dtype(torch::kFloat32)).to(torch::kCUDA);
+
+    // Convert K, V to unpadded format
+    torch::Tensor K_unpad = K_device.reshape({bs * seqlen_kv, num_heads_kv, head_dim});
+    torch::Tensor V_unpad = V_device.reshape({bs * seqlen_kv, num_heads_kv, head_dim});
+
+    auto cu_seqlens_k = torch::arange(0, (bs + 1) * seqlen_kv, seqlen_kv, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));
+    std::optional<at::Tensor> opt_block_table = std::nullopt;
+
+    kvcache_qpack(
+        K_unpad, k_pack, k_params,
+        V_unpad, v_pack, v_params,
+        opt_block_table,
+        cu_seqlens_k,
+        seqlen_kv,
+        quant_mode,
+        group_size
+    );
+
+    K_device = K_device.to(torch::kCPU);
+    V_device = V_device.to(torch::kCPU);
+
+    const float sm_scale = 1 / std::sqrt(float(head_dim));
+    auto out = mha_fwd_kvcache(Q_device,
+                               K_device, k_pack, k_params,
+                               V_device, v_pack, v_params,
+                               sm_scale);
+
+    auto out_cpu = out.to(torch::kCPU);
+
+    auto out_ref = single_mha(Q_host, K_host, V_host, head_dim);
+
+    // Compute the difference
+    torch::Tensor diff = out_cpu - out_ref;
+    float mean_absolute_error = diff.abs().mean().item<float>();
+    float mean_squared_error = diff.pow(2).mean().item<float>();
+
+    printf("\nnum_heads: %d num_heads_kv: %d seqlen_kv: %d head_dim: %d\n", num_heads, num_heads_kv, seqlen_kv, head_dim);
+    if (mean_absolute_error < 1e-1 && mean_squared_error < 1e-1) {
+        printf("test pass ! \n");
+        printf("mean_absolute_error: %f, mean_squared_error: %f\n", mean_absolute_error, mean_squared_error);
+    } else {
+        printf("test fail ! \n");
+        printf("mean_absolute_error: %f, mean_squared_error: %f\n", mean_absolute_error, mean_squared_error);
+    }
+
+    printf("\nFirst 10 elements of out_cpu:\n");
+    auto out_cpu_accessor = static_cast<__half*>(out_cpu.flatten().data_ptr());
+    for (int i = 0; i < 10; i++) {
+        printf("%.6f ", static_cast<float>(out_cpu_accessor[i]));
+    }
+
+    printf("\n\nFirst 10 elements of out_ref:\n");
+    auto out_ref_accessor = static_cast<__half*>(out_ref.flatten().data_ptr());
+    for (int i = 0; i < 10; i++) {
+        printf("%.6f ", static_cast<float>(out_ref_accessor[i]));
+    }
+    printf("\n");
+}
+
+int main() {
+    const int num_heads = 32;
+    const int num_heads_kv = 32;
+    const int head_dim = 128;
+    const int num_bits = 4;
+    const std::string quant_mode = "k-tensor";
+    const int group_size = 128;
+
+    int seqlen_kv = 1024;
+
+    TestDecodingKernelCorrectness<num_heads, num_heads_kv, head_dim, num_bits>(seqlen_kv, quant_mode, group_size);
+
+    return 0;
+}
\ No newline at end of file