device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp Source File

device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp Source File#

Composable Kernel: device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp Source File
device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <functional>
7#include <iostream>
8#include <iterator>
9#include <numeric>
10#include <sstream>
11
26
27#ifdef CK_EXPERIMENTAL_BUILDER
28#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
29#endif
30
31namespace ck {
32namespace tensor_operation {
33namespace device {
34
35//
36// @brief Device Convolution operation.
37//
38// Supports:
39// @li Forward convolution with up to 3 spatial dimensions
40// @li Input tensor in GNWC data format
41// @li Weight tensor in GKXC data format
42// @li Output tensor in GNWK data format
43//
44// 1D:
45// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
46// 2D:
47// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
48// 3D:
49// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
50// Assume:
51// AK1 == BK1
52template <index_t NDimSpatial,
53 typename ALayout,
54 typename BLayout,
55 typename DsLayout,
56 typename ELayout,
57 typename ADataType,
58 typename BDataType,
59 typename AccDataType,
60 typename CShuffleDataType,
61 typename DsDataType,
62 typename EDataType,
63 typename AElementwiseOperation,
64 typename BElementwiseOperation,
65 typename CDEElementwiseOperation,
66 ConvolutionForwardSpecialization ConvForwardSpecialization,
67 GemmSpecialization GemmSpec,
68 index_t NumGemmKPrefetchStage,
69 ck::index_t BlockSize,
70 ck::index_t MPerBlock,
71 ck::index_t NPerBlock,
72 ck::index_t KPerBlock,
73 ck::index_t K1,
74 ck::index_t MPerWmma,
75 ck::index_t NPerWmma,
76 ck::index_t MRepeat,
77 ck::index_t NRepeat,
78 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
79 typename ABlockTransferThreadClusterArrangeOrder,
80 typename ABlockTransferSrcAccessOrder,
81 index_t ABlockTransferSrcVectorDim,
82 index_t ABlockTransferSrcScalarPerVector,
83 index_t ABlockTransferDstScalarPerVector_AK1,
84 bool ABlockLdsExtraM,
85 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
86 typename BBlockTransferThreadClusterArrangeOrder,
87 typename BBlockTransferSrcAccessOrder,
88 index_t BBlockTransferSrcVectorDim,
89 index_t BBlockTransferSrcScalarPerVector,
90 index_t BBlockTransferDstScalarPerVector_BK1,
91 bool BBlockLdsExtraN,
92 index_t CShuffleMRepeatPerShuffle,
93 index_t CShuffleNRepeatPerShuffle,
94 typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
95 index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
99 : public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
100 ALayout,
101 BLayout,
102 DsLayout,
103 ELayout,
104 ADataType,
105 BDataType,
106 DsDataType,
107 EDataType,
108 AElementwiseOperation,
109 BElementwiseOperation,
110 CDEElementwiseOperation>
111{
113
114 static constexpr index_t NumDTensor = DsDataType::Size();
115
116 static constexpr auto I0 = Number<0>{};
117 static constexpr auto I1 = Number<1>{};
118 static constexpr auto I2 = Number<2>{};
119 static constexpr auto I3 = Number<3>{};
120 static constexpr auto I4 = Number<4>{};
121 static constexpr auto I5 = Number<5>{};
122 static constexpr auto I6 = Number<6>{};
123 // K1 = Max Vector Access Pixels
124 static constexpr auto K1Number = Number<K1>{};
125
126 static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
127 static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
128 static constexpr auto WmmaK = 16;
129
130 static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true;
131 static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
132
133 // If true, LDS is used unconditionally
134 static constexpr auto AEnableLds_manu = true;
135 static constexpr auto BEnableLds_manu = true;
136
137 static constexpr auto AEnableLds =
138 AEnableLds_auto || AEnableLds_manu || (NumGemmKPrefetchStage > 1);
139 static constexpr auto BEnableLds =
140 BEnableLds_auto || BEnableLds_manu || (NumGemmKPrefetchStage > 1);
141
143
144 static constexpr auto matrix_padder =
145 MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
146
// Builds the A-side grid descriptor for the implicit GEMM.
// The transformer's raw [M, K] view (layout ALay) is padded by matrix_padder and then
// re-shaped for GridwiseOp:
//  - LDS path: presumably (K0, M, K1) with K0 = K / K1; Invoker::Run recovers K as
//    length(0) * length(2) of this descriptor.
//  - LDS-bypass path: the 7-D layout documented in the comment inside the else branch.
// NOTE(review): AEnableLds is currently always true (AEnableLds_manu == true), so only the
// first branch is instantiated.
147 template <typename ALay>
148 static auto MakeAGridDescriptor(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
149 {
150 const auto in_gemmmraw_gemmkraw_desc =
151 conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
152
153 const auto in_gemmm_gemmk_desc =
154 matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
155
156 const auto M = in_gemmm_gemmk_desc.GetLength(I0);
157 const auto K = in_gemmm_gemmk_desc.GetLength(I1);
// Only enforced in debug builds (assert); K must split evenly into K1-wide vectors.
158 assert(K % K1 == 0);
159
160 if constexpr(AEnableLds)
161 {
162 const index_t K0 = K / K1;
163
165 in_gemmm_gemmk_desc,
170 }
171 else
172 {
// K of one WMMA step is split as K0PerWmma * KRow * K1.
173 constexpr auto A_KRow = 2;
174 constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
175 const auto A_KWmma = K / WmmaK;
176
177 const auto M0 = M / MPerBlock;
178 // 0 1 0 1 2 3 4 5 6
179 // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
181 in_gemmm_gemmk_desc,
185 make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
188 }
189 }
190
// Builds the B-side grid descriptor for the implicit GEMM; the N/K mirror of
// MakeAGridDescriptor. The transformer's raw [N, K] view (layout BLay) is padded by
// matrix_padder and then re-shaped for GridwiseOp (LDS layout presumably (K0, N, K1),
// otherwise the 7-D LDS-bypass layout documented in the else branch).
// NOTE(review): BEnableLds is currently always true (BEnableLds_manu == true).
191 template <typename BLay>
192 static auto MakeBGridDescriptor(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
193 {
194 const auto wei_gemmnraw_gemmkraw_desc =
195 conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
196
197 const auto wei_gemmn_gemmk_desc =
198 matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
199
200 const auto N = wei_gemmn_gemmk_desc.GetLength(I0);
201 const auto K = wei_gemmn_gemmk_desc.GetLength(I1);
// Only enforced in debug builds (assert).
202 assert(K % K1 == 0);
203
204 if constexpr(BEnableLds)
205 {
206 const index_t K0 = K / K1;
207
209 wei_gemmn_gemmk_desc,
214 }
215 else
216 {
217 constexpr auto B_KRow = 2;
218 constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
219 const auto B_KWmma = K / WmmaK;
220
221 const auto N0 = N / NPerBlock;
222 // 0 1 0 1 2 3 4 5 6
223 // N - K <-> B_KWmma - NBlock*NRepeat - NWaves - B_K0PerWmma - B_KRow - NPerWmma - B_K1
225 wei_gemmn_gemmk_desc,
229 make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
232 }
233 }
234
235 template <typename ELay>
236 static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
237 {
238 const auto out_gemmmraw_gemmnraw_desc =
239 conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
240
241 const auto out_gemmm_gemmn_desc =
242 matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
243
244 return out_gemmm_gemmn_desc;
245 }
246
// Builds one padded [M, N] descriptor per D tensor by calling MakeEGridDescriptor_M_N with
// that D's layout (taken from DsLayout).
// NOTE(review): every D uses the single transformer passed in, i.e. whatever output
// lengths/strides it was constructed with. Argument's constructor instead builds a per-D
// transformer from ds_g_n_k_wos_lengths/strides — confirm this helper is only used where
// D and E shapes coincide (e.g. for deriving a descriptor type).
247 static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
248 {
249 return generate_tuple(
250 [&](auto i) {
251 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
252
253 return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer);
254 },
256 }
257
// Descriptor types for the problem definition and the gridwise GEMM instantiation.
// AGridDesc / BGridDesc presumably name the return types of MakeAGridDescriptor /
// MakeBGridDescriptor (their right-hand sides are elided in this listing) — confirm.
258 // desc for problem definition
260 using AGridDesc =
262 using BGridDesc =
268
// GridwiseOp: the gridwise WMMA GEMM with multiple-D CShuffle epilogue, parameterized by
// this device op's data types, descriptors, element-wise ops, tiling and thread-cluster
// configuration.
269 // GridwiseOp
271 // DataType Family
272 ADataType,
273 BDataType,
274 AccDataType,
275 CShuffleDataType,
276 DsDataType,
277 EDataType,
278 // InMemory Data Descriptor
279 AGridDesc,
280 BGridDesc,
283 // ElementwiseOp Family
284 AElementwiseOperation,
285 BElementwiseOperation,
286 CDEElementwiseOperation,
288 // Tiling Family
289 MPerBlock,
290 NPerBlock,
291 KPerBlock,
292 MPerWmma,
293 NPerWmma,
294 K1,
295 MRepeat,
296 NRepeat,
297 // ThreadCluster Family
298 BlockSize,
299 ABlockTransferThreadClusterLengths_AK0_M_AK1,
300 ABlockTransferThreadClusterArrangeOrder,
301 ABlockTransferSrcAccessOrder,
302 ABlockTransferSrcVectorDim,
303 ABlockTransferSrcScalarPerVector,
304 ABlockTransferDstScalarPerVector_AK1,
// NOTE(review): hard-coded flag; its meaning is fixed by GridwiseOp's template parameter
// list — confirm there.
305 false,
307 ABlockLdsExtraM,
308 BBlockTransferThreadClusterLengths_BK0_N_BK1,
309 BBlockTransferThreadClusterArrangeOrder,
310 BBlockTransferSrcAccessOrder,
311 BBlockTransferSrcVectorDim,
312 BBlockTransferSrcScalarPerVector,
313 BBlockTransferDstScalarPerVector_BK1,
// NOTE(review): B-side counterpart of the hard-coded flag above — confirm meaning.
314 false,
316 BBlockLdsExtraN,
317 CShuffleMRepeatPerShuffle,
318 CShuffleNRepeatPerShuffle,
319 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
320 CDEShuffleBlockTransferScalarPerVector_NPerBlock,
321 NumGemmKPrefetchStage,
322 LoopSched,
323 PipelineVer>;
324
325 // Argument
// Runtime argument bundle for one grouped-conv launch.
// It holds:
//  - the typed device pointers and the conv-to-GEMM transformer;
//  - the padded GEMM descriptors and the block-to-E-tile map;
//  - the per-group pointer-offset helper and the element-wise operators;
//  - copies of the original convolution parameters, kept for IsSupportedArgument().
// Shape/stride arrays are in G, N/K, C/K, spatial... order as their parameter names state.
// M01 / N01 are forwarded to GridwiseOp::MakeDefaultBlock2CTileMap.
326 struct Argument : public BaseArgument
327 {
328 Argument(const void* p_a,
329 const void* p_b,
330 const std::array<const void*, NumDTensor>& p_ds,
331 void* p_e,
332 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
333 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
334 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
335 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
336 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
337 ds_g_n_k_wos_lengths,
338 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
339 ds_g_n_k_wos_strides,
340 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
341 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
342 const std::array<index_t, NDimSpatial>& conv_filter_strides,
343 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
344 const std::array<index_t, NDimSpatial>& input_left_pads,
345 const std::array<index_t, NDimSpatial>& input_right_pads,
346 index_t M01,
347 index_t N01,
348 const AElementwiseOperation& a_element_op,
349 const BElementwiseOperation& b_element_op,
350 const CDEElementwiseOperation& cde_element_op)
351 : p_a_grid_{static_cast<const ADataType*>(p_a)},
352 p_b_grid_{static_cast<const BDataType*>(p_b)},
353 p_ds_grid_{},
354 p_e_grid_{static_cast<EDataType*>(p_e)},
// Group count = leading (G) length of the A tensor.
355 num_group_{a_g_n_c_wis_lengths[0]},
356 conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
357 a_g_n_c_wis_strides,
358 b_g_k_c_xs_lengths,
359 b_g_k_c_xs_strides,
360 e_g_n_k_wos_lengths,
361 e_g_n_k_wos_strides,
362 conv_filter_strides,
363 conv_filter_dilations,
364 input_left_pads,
365 input_right_pads},
373 block_2_etile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01)},
375 a_element_op_{a_element_op},
376 b_element_op_{b_element_op},
377 cde_element_op_{cde_element_op},
378 a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
379 a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
380 b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
381 b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
382 ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
383 ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
384 e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
385 e_g_n_k_wos_strides_{e_g_n_k_wos_strides},
386 conv_filter_strides_{conv_filter_strides},
387 conv_filter_dilations_{conv_filter_dilations},
388 input_left_pads_{input_left_pads},
389 input_right_pads_{input_right_pads}
390 {
// Per-group pointer offsets use the stride of the leading (G) dimension of each tensor.
391 // A/B/E Batch Stride
392 compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
393 compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
394 compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
395
396 // populate pointer, batch stride, desc for Ds
397 static_for<0, NumDTensor, 1>{}([&](auto i) {
398 // using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
399 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
400
401 // D pointer
402 p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
403
404 // D batch stride
405 compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
406 });
407
// Each D descriptor comes from a transformer that shares A/B/conv parameters but uses
// that D tensor's own lengths/strides in place of E's.
408 // D desc
410 [&](auto i) {
411 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
412
413 ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
414 a_g_n_c_wis_strides,
415 b_g_k_c_xs_lengths,
416 b_g_k_c_xs_strides,
417 ds_g_n_k_wos_lengths[i],
418 ds_g_n_k_wos_strides[i],
419 conv_filter_strides,
420 conv_filter_dilations,
421 input_left_pads,
422 input_right_pads};
423
424 return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
425 },
427
428 // populate desc for Ds/E
434 }
435
// Dumps the A/B/Ds/E descriptors to std::cout; Invoker::Run calls this when
// stream_config.log_level_ > 0.
436 void Print() const
437 {
438 std::cout << "A[M, K]: " << a_grid_desc_ << std::endl;
439 std::cout << "B[N, K]: " << b_grid_desc_ << std::endl;
441 [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
442 std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
443 }
444
// NOTE: the access specifier is commented out, so all members below are public.
445 // private:
446 // pointers
447 const ADataType* p_a_grid_;
448 const BDataType* p_b_grid_;
450 EDataType* p_e_grid_;
451
452 // tensor descriptors for problem definition
454
456
459
460 // tensor descriptors for block/thread-wise copy
467
468 // block-to-e-tile map
470
471 // for computing batch offset
472 ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;
473
474 // element-wise op
475 AElementwiseOperation a_element_op_;
476 BElementwiseOperation b_element_op_;
477 CDEElementwiseOperation cde_element_op_;
478
479 // for checking IsSupportedArgument()
480 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
481 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
482 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
483 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
484 std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_;
485 std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_;
486 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
487 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
488 std::array<index_t, NDimSpatial> conv_filter_strides_;
489 std::array<index_t, NDimSpatial> conv_filter_dilations_;
490 std::array<index_t, NDimSpatial> input_left_pads_;
491 std::array<index_t, NDimSpatial> input_right_pads_;
492 };
493
494 // Invoker
// Launches the grouped-conv kernel for an Argument built by this device op.
495 struct Invoker : public BaseInvoker
496 {
498
// Run: optionally prints descriptors, then launches one grid that covers every E tile of
// every group, and returns launch_and_time_kernel's float result (presumably the measured
// kernel time when timing is enabled in stream_config — confirm).
499 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
500 {
501 if(stream_config.log_level_ > 0)
502 {
503 arg.Print();
504 }
505
// One workgroup per E tile, replicated for each group.
506 const index_t grid_size =
507 arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_;
508
// Total GEMM K recovered from the A descriptor:
//  - LDS layout: K0 * K1 (dims 0 and 2);
//  - LDS-bypass layout: KWmma * K0PerWmma * KRow * K1 (dims 0, 3, 4 and 6).
509 const auto K = [&]() {
510 if constexpr(AEnableLds)
511 {
512 return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2);
513 }
514 else
515 {
516 return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
517 arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6);
518 }
519 }();
520
// Instantiates the kernel for a compile-time has_main_loop flag and launches it.
521 auto launch_kernel = [&](auto has_main_k_block_loop) {
522 constexpr bool has_main_loop = has_main_k_block_loop.value;
523
526 ADataType,
527 BDataType,
529 EDataType,
530 AElementwiseOperation,
531 BElementwiseOperation,
532 CDEElementwiseOperation,
538 ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
539 has_main_loop>;
540
541 return launch_and_time_kernel(stream_config,
542 kernel,
543 dim3(grid_size),
544 dim3(BlockSize),
// lds_byte argument of launch_and_time_kernel.
545 0,
546 arg.p_a_grid_,
547 arg.p_b_grid_,
548 arg.p_ds_grid_,
549 arg.p_e_grid_,
550 arg.a_element_op_,
551 arg.b_element_op_,
552 arg.cde_element_op_,
553 arg.a_g_n_c_wis_lengths_[0], // Group count
554 arg.a_grid_desc_,
555 arg.b_grid_desc_,
560 };
561
// The (elided) condition presumably asks GridwiseOp whether K needs the main K-block loop.
563 {
564 return launch_kernel(integral_constant<bool, true>{});
565 }
566 else
567 {
568 return launch_kernel(integral_constant<bool, false>{});
569 }
570 }
571
// Type-erased entry point.
// NOTE(review): dynamic_cast yields nullptr when p_arg is not this op's Argument, and the
// result is dereferenced unchecked (undefined behaviour).
572 float Run(const BaseArgument* p_arg,
573 const StreamConfig& stream_config = StreamConfig{}) override
574 {
575 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
576 }
577 };
578
// Host-side check of whether this compiled instance can run the given problem.
// It returns false at the first violated constraint, checked in this order:
//  1. device support;
//  2. ConvForwardSpecialization requirements on filter size, stride and padding;
//  3. vector-access divisibility for A, B, each D and E;
//  4. the gridwise GEMM's own validity check.
// The device-check conditions are elided in this listing; the page footer references
// is_gfx11_supported / is_gfx12_supported — confirm.
579 static bool IsSupportedArgument(const Argument& arg)
580 {
581 namespace ctc = tensor_layout::convolution;
582
583 // check device
585 {
587 {
588 return false;
589 }
590 }
591 else
592 {
593 return false;
594 }
595
596 // check ConvolutionForwardSpecialization
597 if constexpr(ConvForwardSpecialization ==
599 {
// Every spatial filter length (b lengths from index 3 on) must be 1, stride 1, no padding.
600 // check if it's 1x1, stride=1 conv
601 for(index_t i = 0; i < NDimSpatial; ++i)
602 {
603 const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
604 const index_t ConvStride = arg.conv_filter_strides_[i];
605 const index_t LeftPad = arg.input_left_pads_[i];
606 const index_t RightPad = arg.input_right_pads_[i];
607
608 if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
609 {
610 return false;
611 }
612 }
613 }
614 else if constexpr(ConvForwardSpecialization ==
616 {
// Same as above but any stride is accepted.
617 // check if it's 1x1 conv
618 for(index_t i = 0; i < NDimSpatial; ++i)
619 {
620 const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
621 const index_t LeftPad = arg.input_left_pads_[i];
622 const index_t RightPad = arg.input_right_pads_[i];
623
624 if(!(X == 1 && LeftPad == 0 && RightPad == 0))
625 {
626 return false;
627 }
628 }
629 }
630
// A: for the accepted layouts (condition elided), C (index 2) must be divisible by the
// source vector width and vectorization must be along dim 2; other layouts are rejected.
631 // check vector access of A
632 // FIXME: layout
638 {
639 const index_t C = arg.a_g_n_c_wis_lengths_[2];
640
641 if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
642 {
643 return false;
644 }
645 }
646 else
647 {
648 return false;
649 }
650
// B: same rule on the weight tensor's C length.
651 // check vector access of B
652 // FIXME: layout
658
659 {
660 const index_t C = arg.b_g_k_c_xs_lengths_[2];
661
662 if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
663 {
664 return false;
665 }
666 }
667 else
668 {
669 return false;
670 }
671
// Ds: K (index 2) of every D must be divisible by the CDE shuffle vector width.
// The result is accumulated in `valid` because static_for's lambda cannot return early.
672 // check vector access of Ds
673 bool valid = true;
674
675 static_for<0, NumDTensor, 1>{}([&](auto i) {
676 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
677
678 // FIXME: layout
684 {
685 const index_t K = arg.ds_g_n_k_wos_lengths_[i][2];
686
687 if(!(K % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0))
688 {
689 valid = false;
690 }
691 }
692 else
693 {
694 valid = false;
695 }
696 });
697
698 if(!valid)
699 {
700 return false;
701 }
702
// E: same K divisibility rule as the Ds.
703 // check vector access of E
709 {
710 const index_t K = arg.e_g_n_k_wos_lengths_[2];
711
712 if(!(K % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0))
713 {
714 return false;
715 }
716 }
717 else
718 {
719 return false;
720 }
721
// Final verdict comes from GridwiseOp's own check on the descriptors (call partly elided).
722 // check Gridwise GEMM
724 arg.b_grid_desc_,
728 }
729
730 bool IsSupportedArgument(const BaseArgument* p_arg) override
731 {
732 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
733 }
734
735 static auto MakeArgument(
736 const void* p_a,
737 const void* p_b,
738 const std::array<const void*, NumDTensor>& p_ds,
739 void* p_e,
740 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
741 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
742 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
743 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
744 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
745 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
746 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
747 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
748 const std::array<index_t, NDimSpatial>& conv_filter_strides,
749 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
750 const std::array<index_t, NDimSpatial>& input_left_pads,
751 const std::array<index_t, NDimSpatial>& input_right_pads,
752 const AElementwiseOperation& a_element_op,
753 const BElementwiseOperation& b_element_op,
754 const CDEElementwiseOperation& cde_element_op)
755 {
756 return Argument{p_a,
757 p_b,
758 p_ds,
759 p_e,
760 a_g_n_c_wis_lengths,
761 a_g_n_c_wis_strides,
762 b_g_k_c_xs_lengths,
763 b_g_k_c_xs_strides,
764 ds_g_n_k_wos_lengths,
765 ds_g_n_k_wos_strides,
766 e_g_n_k_wos_lengths,
767 e_g_n_k_wos_strides,
768 conv_filter_strides,
769 conv_filter_dilations,
770 input_left_pads,
771 input_right_pads,
772 1,
773 1,
774 a_element_op,
775 b_element_op,
776 cde_element_op};
777 }
778
779 static auto
780 MakeArgument(const void* p_a,
781 const void* p_b,
782 const std::array<const void*, NumDTensor>& p_ds,
783 void* p_e,
784 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
785 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
786 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
787 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
788 const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
789 ds_g_n_k_wos_lengths,
790 const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
791 ds_g_n_k_wos_strides,
792 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
793 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
794 const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
795 const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
796 const std::array<long_index_t, NDimSpatial>& input_left_pads,
797 const std::array<long_index_t, NDimSpatial>& input_right_pads,
798 const AElementwiseOperation& a_element_op,
799 const BElementwiseOperation& b_element_op,
800 const CDEElementwiseOperation& cde_element_op)
801 {
802 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
803 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
804 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
805 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
806 std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
807 std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
808 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
809 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
810 std::array<index_t, NDimSpatial> conv_filter_strides_i32;
811 std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
812 std::array<index_t, NDimSpatial> input_left_pads_i32;
813 std::array<index_t, NDimSpatial> input_right_pads_i32;
814
815 array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
816 array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
817 array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
818 array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
819 for(index_t d = 0; d < NumDTensor; d++)
820 {
821 array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
822 array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
823 }
824 array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
825 array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
826 array_convert(conv_filter_strides_i32, conv_filter_strides);
827 array_convert(conv_filter_dilations_i32, conv_filter_dilations);
828 array_convert(input_left_pads_i32, input_left_pads);
829 array_convert(input_right_pads_i32, input_right_pads);
830
831 return Argument{p_a,
832 p_b,
833 p_ds,
834 p_e,
835 a_g_n_c_wis_lengths_i32,
836 a_g_n_c_wis_strides_i32,
837 b_g_k_c_xs_lengths_i32,
838 b_g_k_c_xs_strides_i32,
839 ds_g_n_k_wos_lengths_i32,
840 ds_g_n_k_wos_strides_i32,
841 e_g_n_k_wos_lengths_i32,
842 e_g_n_k_wos_strides_i32,
843 conv_filter_strides_i32,
844 conv_filter_dilations_i32,
845 input_left_pads_i32,
846 input_right_pads_i32,
847 1,
848 1,
849 a_element_op,
850 b_element_op,
851 cde_element_op};
852 }
853
854 static auto MakeInvoker() { return Invoker{}; }
855
856 std::unique_ptr<BaseArgument> MakeArgumentPointer(
857 const void* p_a,
858 const void* p_b,
859 const std::array<const void*, NumDTensor>& p_ds,
860 void* p_e,
861 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
862 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
863 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
864 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
865 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
866 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
867 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
868 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
869 const std::array<index_t, NDimSpatial>& conv_filter_strides,
870 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
871 const std::array<index_t, NDimSpatial>& input_left_pads,
872 const std::array<index_t, NDimSpatial>& input_right_pads,
873 const AElementwiseOperation& a_element_op,
874 const BElementwiseOperation& b_element_op,
875 const CDEElementwiseOperation& cde_element_op) override
876 {
877 return std::make_unique<Argument>(p_a,
878 p_b,
879 p_ds,
880 p_e,
881 a_g_n_c_wis_lengths,
882 a_g_n_c_wis_strides,
883 b_g_k_c_xs_lengths,
884 b_g_k_c_xs_strides,
885 ds_g_n_k_wos_lengths,
886 ds_g_n_k_wos_strides,
887 e_g_n_k_wos_lengths,
888 e_g_n_k_wos_strides,
889 conv_filter_strides,
890 conv_filter_dilations,
891 input_left_pads,
892 input_right_pads,
893 1,
894 1,
895 a_element_op,
896 b_element_op,
897 cde_element_op);
898 }
899
900 std::unique_ptr<BaseArgument>
901 MakeArgumentPointer(const void* p_a,
902 const void* p_b,
903 const std::array<const void*, NumDTensor>& p_ds,
904 void* p_e,
905 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
906 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
907 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
908 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
909 const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
910 ds_g_n_k_wos_lengths,
911 const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
912 ds_g_n_k_wos_strides,
913 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
914 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
915 const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
916 const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
917 const std::array<long_index_t, NDimSpatial>& input_left_pads,
918 const std::array<long_index_t, NDimSpatial>& input_right_pads,
919 const AElementwiseOperation& a_element_op,
920 const BElementwiseOperation& b_element_op,
921 const CDEElementwiseOperation& cde_element_op) override
922 {
923 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
924 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
925 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
926 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
927 std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
928 std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
929 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
930 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
931 std::array<index_t, NDimSpatial> conv_filter_strides_i32;
932 std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
933 std::array<index_t, NDimSpatial> input_left_pads_i32;
934 std::array<index_t, NDimSpatial> input_right_pads_i32;
935
936 array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
937 array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
938 array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
939 array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
940 for(index_t d = 0; d < NumDTensor; d++)
941 {
942 array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
943 array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
944 }
945 array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
946 array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
947 array_convert(conv_filter_strides_i32, conv_filter_strides);
948 array_convert(conv_filter_dilations_i32, conv_filter_dilations);
949 array_convert(input_left_pads_i32, input_left_pads);
950 array_convert(input_right_pads_i32, input_right_pads);
951
952 return std::make_unique<Argument>(p_a,
953 p_b,
954 p_ds,
955 p_e,
956 a_g_n_c_wis_lengths_i32,
957 a_g_n_c_wis_strides_i32,
958 b_g_k_c_xs_lengths_i32,
959 b_g_k_c_xs_strides_i32,
960 ds_g_n_k_wos_lengths_i32,
961 ds_g_n_k_wos_strides_i32,
962 e_g_n_k_wos_lengths_i32,
963 e_g_n_k_wos_strides_i32,
964 conv_filter_strides_i32,
965 conv_filter_dilations_i32,
966 input_left_pads_i32,
967 input_right_pads_i32,
968 1,
969 1,
970 a_element_op,
971 b_element_op,
972 cde_element_op);
973 }
974
975 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
976 {
977 return std::make_unique<Invoker>(Invoker{});
978 }
979
980 std::string GetTypeString() const override
981 {
982 auto str = std::stringstream();
983
984 // clang-format off
985 str << "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle"
986 << "<"
987 << BlockSize << ", "
988 << MPerBlock << ", "
989 << NPerBlock << ", "
990 << KPerBlock << ", "
991 << getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
992 << K1 << ", "
993 << MPerWmma << ", "
994 << NPerWmma << ", "
995 << MRepeat << ", "
996 << NRepeat
997 << ">"
998 << " AEnableLds: "
999 << AEnableLds << ", "
1000 << "BEnableLds: "
1001 << BEnableLds << ", "
1002 << "ABlockTransferSrcScalarPerVector: "
1003 << ABlockTransferSrcScalarPerVector << ", "
1004 << "BBlockTransferSrcScalarPerVector: "
1005 << BBlockTransferSrcScalarPerVector;
1006 // clang-format on
1007
1008 return str.str();
1009 }
1010
#ifdef CK_EXPERIMENTAL_BUILDER
// Reflection-based instance description.
// Compilation fails with a pointed message when no instance_traits specialization exists
// for this exact instantiation.
std::string GetInstanceString() const override
{
    constexpr bool has_instance_traits = ck_tile::reflect::HasInstanceTraits<DeviceOp>;
    static_assert(has_instance_traits,
                  "Specialization of instance_traits not found. Please check that a "
                  "specialization exists in file "
                  "ck_tile/builder/reflect/"
                  "instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp "
                  "for the given template parameters.");
    return ck_tile::reflect::instance_string<DeviceOp>();
}
#endif
1023};
1024
1025} // namespace device
1026} // namespace tensor_operation
1027} // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition tensor_operation/gpu/device/tensor_layout.hpp:42
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
ConvolutionForwardSpecialization
Definition convolution_forward_specialization.hpp:15
@ Filter1x1Stride1Pad0
Definition convolution_forward_specialization.hpp:18
@ Filter1x1Pad0
Definition convolution_forward_specialization.hpp:17
std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization &s)
Definition convolution_forward_specialization.hpp:24
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
__global__ void kernel_grouped_conv_multiple_d_wmma_cshuffle(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const index_t batch_count, const AGridDesc_AK0_M_AK1 a_grid_desc, const BGridDesc_BK0_N_BK1 b_grid_desc, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_, const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:40
__host__ __device__ void array_convert(std::array< Y, NumElems > &y, const std::array< X, NumElems > &x)
Definition utility/type_convert.hpp:2466
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:326
Definition multi_index_transform.hpp:196
Definition multi_index_transform.hpp:284
Definition utility/sequence.hpp:43
Definition functional2.hpp:33
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:25
Definition device_base.hpp:197
virtual std::string GetInstanceString() const
Definition device_base.hpp:230
Grouped Convolution Forward.
Definition device_grouped_conv_fwd_multiple_abd.hpp:73
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:327
std::array< index_t, NDimSpatial+3 > b_g_k_c_xs_lengths_
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:482
std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > ds_g_n_k_wos_strides_
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:485
BGridDesc b_grid_desc_
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:462
std::array< index_t, NDimSpatial+3 > b_g_k_c_xs_strides_
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:483
std::array< index_t, NDimSpatial > conv_filter_dilations_
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:489
void Print() const
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:436
GridwiseOp::DefaultBlock2CTileMap block_2_etile_map_
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:469
ComputePtrOffsetOfStridedBatch< I1, I1, NumDTensor > compute_ptr_offset_of_batch_
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:472
std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > ds_g_n_k_wos_lengths_
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:484
GridwiseOp::DsGridPointer p_ds_grid_
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:449
std::array< index_t, NDimSpatial+3 > a_g_n_c_wis_lengths_
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:480
const BDataType * p_b_grid_
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:448
CDEElementwiseOperation cde_element_op_
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:477
std::array< index_t, NDimSpatial+3 > e_g_n_k_wos_lengths_
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:486
std::array< index_t, NDimSpatial+3 > a_g_n_c_wis_strides_
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:481
std::array< index_t, NDimSpatial > input_left_pads_
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:490
AGridDesc a_grid_desc_
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:461
EDataType * p_e_grid_
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:450
Argument(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, index_t M01, index_t N01, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:328
std::array< index_t, NDimSpatial > input_right_pads_
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:491
index_t num_group_
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:453
EGridDesc_M_N e_grid_desc_m_n_
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:458
const ADataType * p_a_grid_
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:447
DsGridDesc_M_N ds_grid_desc_m_n_
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:457
std::array< index_t, NDimSpatial+3 > e_g_n_k_wos_strides_
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:487
std::array< index_t, NDimSpatial > conv_filter_strides_
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:488
GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:466
GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:464
AElementwiseOperation a_element_op_
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:475
BElementwiseOperation b_element_op_
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:476
ConvToGemmFwdTransformer conv_to_gemm_transformer_
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:455
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:496
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:572
DeviceOp::Argument Argument
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:497
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:499
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:111
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:236
static constexpr auto I0
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:116
static constexpr auto I3
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:119
decltype(DeviceOp::MakeBGridDescriptor< BLayout >(dummy_conv_to_gemm_transformer)) BGridDesc
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:262
TransformConvFwdToGemm< NDimSpatial, ConvForwardSpecialization > ConvToGemmFwdTransformer
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:142
static auto MakeAGridDescriptor(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:148
static constexpr auto I2
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:118
static constexpr auto K1Number
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:124
static auto MakeArgument(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:735
static constexpr auto AEnableLds_auto
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:130
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:975
static constexpr auto MWaves
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:126
static constexpr auto BEnableLds_auto
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:131
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_dilations, const std::array< long_index_t, NDimSpatial > &input_left_pads, const std::array< long_index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op) override
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:901
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))> DsGridDesc_M_N
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:264
static constexpr auto BEnableLds
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:139
static constexpr auto NWaves
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:127
remove_cvref_t< decltype(MakeEGridDescriptor_M_N< ELayout >(dummy_conv_to_gemm_transformer))> EGridDesc_M_N
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:266
static constexpr ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:259
static auto MakeBGridDescriptor(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:192
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:579
decltype(DeviceOp::MakeAGridDescriptor< ALayout >(dummy_conv_to_gemm_transformer)) AGridDesc
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:260
static constexpr auto matrix_padder
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:144
static constexpr auto I1
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:117
static constexpr auto AEnableLds_manu
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:134
static constexpr auto I4
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:120
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op) override
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:856
static auto MakeArgument(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_dilations, const std::array< long_index_t, NDimSpatial > &input_left_pads, const std::array< long_index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:780
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:247
static constexpr auto I5
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:121
static constexpr auto I6
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:122
static constexpr auto WmmaK
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:128
static constexpr auto AEnableLds
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:137
static constexpr auto BEnableLds_manu
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:135
static auto MakeInvoker()
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:854
std::string GetTypeString() const override
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:980
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:730
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle DeviceOp
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:112
GridwiseGemmMultipleD_Wmma< ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AGridDesc, BGridDesc, DsGridDesc_M_N, EGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, K1, MRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, AEnableLds, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BEnableLds, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVector_NPerBlock, NumGemmKPrefetchStage, LoopSched, PipelineVer > GridwiseOp
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:270
static constexpr index_t NumDTensor
Definition device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp:114
Definition matrix_padder.hpp:180