15#define CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD 0
29 template <
typename Problem>
32 constexpr index_t kBlockSize = Problem::kBlockSize;
33 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
34 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
36 constexpr index_t MaxVectorSize = 16 /
sizeof(
typename Problem::QDataType);
39 constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
40 static_assert(0 < ElemPerThread);
41 return min(ElemPerThread, MaxVectorSize);
44 template <
typename Problem>
49 return static_cast<index_t>(16 /
sizeof(OaccDataType));
52 template <
typename Problem>
55 constexpr index_t kBlockSize = Problem::kBlockSize;
56 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
57 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
59 constexpr index_t MaxVectorSize = 16 /
sizeof(
typename Problem::KDataType);
61 constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
62 static_assert(0 < ElemPerThread);
63 return min(ElemPerThread, MaxVectorSize);
66 template <
typename Problem>
69 constexpr index_t kBlockSize = Problem::kBlockSize;
70 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
71 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
73 constexpr index_t MaxVectorSize = 16 /
sizeof(
typename Problem::VDataType);
75 constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
76 static_assert(0 < ElemPerThread);
77 return min(ElemPerThread, MaxVectorSize);
80 template <
typename Problem,
bool BypassLDS = false>
83 if constexpr(!BypassLDS)
85 constexpr index_t kBlockSize = Problem::kBlockSize;
86 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
87 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
89 constexpr index_t MaxVectorSize = 16 /
sizeof(
typename Problem::QDataType);
91 constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
92 static_assert(0 < ElemPerThread);
93 constexpr index_t kMaxVecLoad =
min(ElemPerThread, MaxVectorSize);
95 constexpr index_t KPerThread = kMaxVecLoad;
96 constexpr index_t KThreads = kKPerBlock / KPerThread;
99 constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
113 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
114 using WarpGemm =
remove_cvref_t<
decltype(config.template at<0>())>;
116 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(
number<0>{});
117 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(
number<1>{});
119 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
120 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
122 constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
123 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
134 q_block_outer_dstr_encoding,
typename WarpGemm::AWarpDstrEncoding{});
142 template <
typename Problem,
bool LoadOnce = false>
147 constexpr index_t kBlockSize = Problem::kBlockSize;
148 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
150 LoadOnce ? Problem::BlockFmhaShape::kSubQKHeaddim : Problem::BlockFmhaShape::kK0;
152 constexpr index_t MaxVectorSize = 16 /
sizeof(KDataType);
153 constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
155 constexpr index_t K1 =
min(MaxVectorSize, ElemPerThread);
156 constexpr index_t K0 = kKPerBlock / K1;
159 constexpr index_t N0 = kNPerBlock / (N2 * N1);
170 template <
typename Problem>
174 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
175 using WarpGemm =
remove_cvref_t<
decltype(config.template at<0>())>;
177 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(
number<0>{});
178 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(
number<1>{});
180 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
181 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
183 constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
184 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
188 constexpr auto q_block_outer_dstr_encoding =
197 q_block_outer_dstr_encoding,
typename WarpGemm::AWarpDstrEncoding{});
204 template <
typename Problem>
209 return static_cast<index_t>(16 /
sizeof(QDataType));
212 template <
typename Problem,
bool Xor = false>
215 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
216 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
220 constexpr auto q_lds_block_desc = [&]() {
223#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
224 constexpr auto LDSLayerSize = 256 /
sizeof(
typename Problem::QDataType);
225 constexpr auto XorLengthFold = LDSLayerSize / kKPerBlock;
227 if constexpr(XorLengthFold > 1)
231 number<LDSLayerSize / kKPack>{},
238 q_lds_block_desc_naive,
241 number<LDSLayerSize / kKPack>{})),
247 q_lds_block_desc_permuted,
257 q_lds_block_desc_tmp,
277 q_lds_block_desc_naive,
279 number<kKPerBlock / kKPack>{})),
285 q_lds_block_desc_permuted,
303 return q_lds_block_desc;
306 template <
typename Problem,
bool LoadOnce = false,
bool Xor = false>
309 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
311 LoadOnce ? Problem::BlockFmhaShape::kSubQKHeaddim : Problem::BlockFmhaShape::kK0;
315 constexpr auto k_lds_block_desc = [&]() {
318#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
319 constexpr auto LDSLayerSize = 256 /
sizeof(
typename Problem::KDataType);
320 constexpr auto XorLengthFold = LDSLayerSize / kKPerBlock;
322 if constexpr(XorLengthFold > 1)
326 number<LDSLayerSize / kKPack>{},
333 k_lds_block_desc_naive,
336 number<LDSLayerSize / kKPack>{})),
342 k_lds_block_desc_permuted,
352 k_lds_block_desc_tmp,
372 k_lds_block_desc_naive,
374 number<kKPerBlock / kKPack>{})),
380 k_lds_block_desc_permuted,
398 return k_lds_block_desc;
401 template <
typename Problem,
bool Xor = false>
404 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
405 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
409 constexpr auto v_lds_block_desc = [&]() {
412 constexpr auto XorGroupSize =
413 Problem::BlockFmhaShape::Gemm1WarpTile::at(
number<0>{});
415#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
416 constexpr auto LDSLayerSize = 256 /
sizeof(
typename Problem::VDataType);
417 constexpr auto XorLengthFold = LDSLayerSize / kNPerBlock;
419 if constexpr(XorLengthFold > 1)
423 number<LDSLayerSize / XorGroupSize>{},
430 v_lds_block_desc_naive,
433 number<LDSLayerSize / XorGroupSize>{})),
439 v_lds_block_desc_permuted,
443 number<kNPerBlock / XorGroupSize>{})),
449 v_lds_block_desc_tmp,
463 number<kNPerBlock / XorGroupSize>{},
470 v_lds_block_desc_naive,
478 v_lds_block_desc_permuted,
497 return v_lds_block_desc;
500 template <
typename Problem>
505 typename Problem::KDataType,
506 typename Problem::SaccDataType,
509 Problem::BlockFmhaShape::kN0,
510 Problem::BlockFmhaShape::kK0>,
511 typename Problem::BlockFmhaShape::Gemm0BlockWarps,
512 typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
515 typename Problem::KDataType,
516 typename Problem::SaccDataType,
517 Problem::BlockFmhaShape::Gemm0WarpTile::at(
number<0>{}),
518 Problem::BlockFmhaShape::Gemm0WarpTile::at(
number<1>{}),
519 Problem::BlockFmhaShape::Gemm0WarpTile::at(
number<2>{}),
522 using BlockGemmPolicy =
524 typename Problem::KDataType,
525 typename Problem::SaccDataType,
526 typename Problem::BlockFmhaShape::Gemm0BlockWarps,
533 template <
typename Problem>
538 typename Problem::VDataType,
539 typename Problem::OaccDataType,
542 Problem::BlockFmhaShape::kN1,
543 Problem::BlockFmhaShape::kK1>,
544 typename Problem::BlockFmhaShape::Gemm1BlockWarps,
545 typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
549 typename Problem::VDataType,
550 typename Problem::OaccDataType,
551 Problem::BlockFmhaShape::Gemm1WarpTile::at(
number<0>{}),
552 Problem::BlockFmhaShape::Gemm1WarpTile::at(
number<1>{}),
553 Problem::BlockFmhaShape::Gemm1WarpTile::at(
number<2>{}),
557 ((Problem::BlockFmhaShape::Gemm1WarpTile::at(
number<1>{}) == 16 &&
558 Problem::BlockFmhaShape::Gemm1WarpTile::at(
number<2>{}) == 32) ||
559 (Problem::BlockFmhaShape::Gemm1WarpTile::at(
number<1>{}) == 32 &&
560 Problem::BlockFmhaShape::Gemm1WarpTile::at(
number<2>{}) == 16))
564 using BlockGemmPolicy =
566 typename Problem::VDataType,
567 typename Problem::OaccDataType,
568 typename Problem::BlockFmhaShape::Gemm1BlockWarps,
575 template <
typename Problem>
579 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
580 using WarpGemm =
remove_cvref_t<
decltype(config.template at<0>())>;
582 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(
number<0>{});
583 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(
number<1>{});
585 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
586 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
588 constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
589 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
593 constexpr auto k_block_outer_dstr_encoding =
602 k_block_outer_dstr_encoding,
typename WarpGemm::BWarpDstrEncoding{});
609 template <
typename Problem>
612 constexpr index_t kBlockSize = Problem::kBlockSize;
613 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
614 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
616 constexpr index_t MaxVectorSize = 16 /
sizeof(
typename Problem::VDataType);
618 constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
619 static_assert(0 < ElemPerThread);
620 constexpr index_t kMaxVecLoad =
min(ElemPerThread, MaxVectorSize);
622 constexpr index_t NPerThread = kMaxVecLoad;
623 constexpr index_t NThreads = kNPerBlock / NPerThread;
626 constexpr index_t KPerThread = kKPerBlock / (KThreadPerWarp * NumWarps);
638 template <
typename Problem>
642 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
643 using WarpGemm =
remove_cvref_t<
decltype(config.template at<0>())>;
645 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(
number<0>{});
646 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(
number<1>{});
648 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
649 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
651 constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
652 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
656 constexpr auto p_block_outer_dstr_encoding =
665 p_block_outer_dstr_encoding,
typename WarpGemm::AWarpDstrEncoding{});
672 template <
typename Problem>
676 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
677 using WarpGemm =
remove_cvref_t<
decltype(config.template at<0>())>;
679 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(
number<0>{});
680 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(
number<1>{});
682 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
683 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
685 constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
686 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
690 constexpr auto v_block_outer_dstr_encoding =
699 v_block_outer_dstr_encoding,
typename WarpGemm::BWarpDstrEncoding{});
701 constexpr auto v_block_dstr =
703 decltype(v_block_dstr_encode),
704 typename Problem::VDataType>::TransposedDstrEncode{});
709 template <
typename Problem>
713 return static_cast<index_t>(16 /
sizeof(SDataType));
716 template <
typename Problem>
719 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
720 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
737 return s_lds_block_desc;
740 template <
typename Problem>
745 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
747 constexpr index_t MWarp = config.template at<1>();
748 constexpr index_t NWarp = config.template at<2>();
752 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
753 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
754 constexpr index_t kTileK = Problem::BlockFmhaShape::kN0;
757 constexpr index_t K3 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane;
758 constexpr index_t K2 = WG::WarpGemmAttribute::Impl::kABKLane;
759 constexpr index_t K1 = kKPerBlock / (K2 * K3);
760 constexpr index_t K0 = kTileK / kKPerBlock;
761 constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane;
763 constexpr index_t M0 = kMPerBlock / (M2 * M1);
765 constexpr auto s2_block_dstr_encoding =
775 return s2_block_dstr;
778 template <
typename Problem>
782 sizeof(
typename Problem::QDataType);
785 template <
typename Problem,
bool LoadOnce = false>
789 sizeof(
typename Problem::KDataType);
792 template <
typename Problem>
796 sizeof(
typename Problem::VDataType);
799 template <
typename Problem>
802 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(
number<1>{});
805 sizeof(
typename Problem::SaccDataType)
809 template <
typename Problem>
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
Definition tile/core/algorithm/cluster_descriptor.hpp:13
@ Single
Definition warp_gemm_attribute_mfma.hpp:14
@ Double
Definition warp_gemm_attribute_mfma.hpp:15
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
TransposeTileDistributionTraits< TileDistributionEncoding_, DataType_, Policy, true > InputTileDistributionTraits
Definition load_tile_transpose.hpp:343
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
CK_TILE_HOST_DEVICE constexpr auto make_unmerge_transform(const UpLengths &up_lengths, bool_constant< Use24BitIntegerCalculation >=bool_constant< false >{})
Definition coordinate_transform.hpp:1622
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1662
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1609
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
@ KMN
Definition block_gemm_areg_breg_creg_v2_custom_policy.hpp:12
@ MNK
Definition block_gemm_areg_breg_creg_v2_custom_policy.hpp:13
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:23
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentV()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:67
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSizeQ()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:779
static CK_TILE_HOST_DEVICE constexpr auto MakeVRegTileDistribution()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:673
static CK_TILE_HOST_DEVICE constexpr auto MakeQDramTileDistribution()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:81
static CK_TILE_HOST_DEVICE constexpr auto MakeKRegTileDistribution()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:576
static CK_TILE_HOST_DEVICE constexpr auto GetPVBlockGemm()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:534
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackQ()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:205
static CK_TILE_HOST_DEVICE constexpr auto MakeVLdsBlockDescriptor()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:402
static CK_TILE_HOST_DEVICE constexpr auto GetSmemNPackS()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:710
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSizeS()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:800
static CK_TILE_HOST_DEVICE constexpr auto MakeQLdsBlockDescriptor()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:213
static CK_TILE_HOST_DEVICE constexpr auto GetQKBlockGemm()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:501
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentOacc()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:45
static CK_TILE_HOST_DEVICE constexpr auto MakeKLdsBlockDescriptor()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:307
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSizeK()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:786
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentK()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:53
BlockFmhaPipelineQXKSVSCustomPolicy< true, false, 1, 1 > BasePolicy
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:24
static CK_TILE_HOST_DEVICE constexpr auto MakeQRegTileDistribution()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:171
static CK_TILE_HOST_DEVICE constexpr auto MakeSLdsBlockDescriptor()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:717
static CK_TILE_HOST_DEVICE constexpr auto MakePRegTileDistribution()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:639
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSizeV()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:793
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:810
static CK_TILE_HOST_DEVICE constexpr auto MakeSRegTileDistribution()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:741
static CK_TILE_HOST_DEVICE constexpr auto MakeKDramTileDistribution()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:143
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentQ()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:30
static CK_TILE_DEVICE constexpr auto MakeVDramTileDistribution()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:610
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:266
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackV()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:373
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackK()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:338
Definition block_gemm_areg_breg_creg_v2_custom_policy.hpp:23
Definition block_gemm_areg_breg_creg_v2.hpp:17
Definition block_gemm_problem.hpp:18
Definition tile_gemm_shape.hpp:17
Definition tile/core/container/sequence.hpp:49
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192