threadwise_tensor_slice_transfer_util.hpp Source File

threadwise_tensor_slice_transfer_util.hpp Source File#

Composable Kernel: threadwise_tensor_slice_transfer_util.hpp Source File
threadwise_tensor_slice_transfer_util.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3#pragma once
4
5namespace ck {
6
7// Do the following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
8// and sometimes useless instructions:
9// 1. Don't save a reference to the tensor descriptor in the class; pass it in as an argument
10// instead
11// 2. Don't construct a new tensor coordinate every time it is used; update and reuse the same
12// tensor coordinate instead
13// 3. Don't use a pointer to a VGPR buffer; use a vector instead
14
15namespace detail {
16// TODO: How to fix this? It uses a struct instead of a lambda because a lambda
17// doesn't have a constructor
// Functor parameterized on VectorDim and ScalarPerVector: operator()(i) yields ScalarPerVector
// when i equals VectorDim and 1 for every other i.
// NOTE(review): the struct declaration line (listing line 19) is missing from this rendered
// listing, so the type's name is not visible here -- restore it from the original header.
18template <index_t VectorDim, index_t ScalarPerVector>
20{
21 __host__ __device__ constexpr auto operator()(index_t i) const
22 {
23 return (i == VectorDim) ? ScalarPerVector : 1;
24 }
25};
26
// Functor parameterized on VectorDim: operator()(i) yields 1 when i equals VectorDim and 0 for
// every other i.
// NOTE(review): the struct declaration line (listing line 28) is missing from this rendered
// listing, so the type's name is not visible here -- restore it from the original header.
27template <index_t VectorDim>
29{
30 __host__ __device__ constexpr auto operator()(index_t i) const
31 {
32 return (i == VectorDim) ? 1 : 0;
33 }
34};
35
36// TODO: How to fix this? It uses a struct instead of a lambda because a lambda
37// doesn't have a constructor
// Functor parameterized on a source and a destination (vector dimension, scalars-per-vector)
// pair. operator()(i) yields:
//   - math::lcm(SrcScalarPerVector, DstScalarPerVector) when i equals both SrcVectorDim and
//     DstVectorDim,
//   - SrcScalarPerVector when i equals only SrcVectorDim,
//   - DstScalarPerVector when i equals only DstVectorDim,
//   - 1 otherwise.
// NOTE(review): the struct declaration line (listing line 42) is missing from this rendered
// listing, so the type's name is not visible here -- restore it from the original header.
38template <index_t SrcVectorDim,
39 index_t SrcScalarPerVector,
40 index_t DstVectorDim,
41 index_t DstScalarPerVector>
43{
44 __host__ __device__ constexpr auto operator()(index_t i) const
45 {
46 if(i == SrcVectorDim && i == DstVectorDim)
47 {
48 return math::lcm(SrcScalarPerVector, DstScalarPerVector);
49 }
50 else if(i == SrcVectorDim)
51 {
52 return SrcScalarPerVector;
53 }
54 else if(i == DstVectorDim)
55 {
56 return DstScalarPerVector;
57 }
58 else
59 {
60 return 1;
61 }
62 }
63};
64
// Functor parameterized on WaveNum and nDim: operator()(i) yields WaveNum when nDim - i == 3
// (i.e. i == nDim - 3) and 1 for every other i.
// NOTE(review): the struct declaration line (listing line 66) is missing from this rendered
// listing, so the type's name is not visible here -- restore it from the original header.
65template <index_t WaveNum, index_t nDim>
67{
68 __host__ __device__ constexpr auto operator()(index_t i) const
69 {
70 if((nDim - i) == 3)
71 return WaveNum;
72 else
73 return 1;
74 }
75};
76
77} // namespace detail
78
79} // namespace ck
Definition threadwise_tensor_slice_transfer_util.hpp:15
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition threadwise_tensor_slice_transfer_util.hpp:43
__host__ __device__ constexpr auto operator()(index_t i) const
Definition threadwise_tensor_slice_transfer_util.hpp:44
Definition threadwise_tensor_slice_transfer_util.hpp:20
__host__ __device__ constexpr auto operator()(index_t i) const
Definition threadwise_tensor_slice_transfer_util.hpp:21
Definition threadwise_tensor_slice_transfer_util.hpp:29
__host__ __device__ constexpr auto operator()(index_t i) const
Definition threadwise_tensor_slice_transfer_util.hpp:30
Definition threadwise_tensor_slice_transfer_util.hpp:67
__host__ __device__ constexpr auto operator()(index_t i) const
Definition threadwise_tensor_slice_transfer_util.hpp:68