reduction_functions_threadwise.hpp Source File

reduction_functions_threadwise.hpp Source File#

Composable Kernel: reduction_functions_threadwise.hpp Source File
reduction_functions_threadwise.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Assumptions:
11// 1) SrcThreadDesc_M_K is known at compile time
12// 2) DstThreadDesc_M is known at compile time
13// 3) the source buffer is a static buffer
14// 4) the destination buffer is a static buffer
15template <typename AccDataType,
16 typename SrcThreadDesc_M_K,
17 typename DstThreadDesc_M,
18 typename OpReduce,
19 bool PropagateNan,
20 typename Accumulation =
23{
24 static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{};
25 static constexpr auto dst_thread_desc_m = DstThreadDesc_M{};
26
27 static constexpr auto src_length_m = src_thread_desc_m_k.GetLength(Number<0>{});
28 static constexpr auto src_length_k = src_thread_desc_m_k.GetLength(Number<1>{});
29 static constexpr auto dst_length_m = dst_thread_desc_m.GetLength(Number<0>{});
30
31 static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
32
33 using Op = OpReduce;
34
35 template <typename SrcBufferType, typename DstBufferType>
36 __device__ static void Reduce(const SrcBufferType& src_buf, DstBufferType& dst_buf)
37 {
38 static_for<0, src_length_m, 1>{}([&](auto iM) {
39 constexpr index_t out_offset = dst_thread_desc_m.CalculateOffset(make_tuple(iM));
40
41 static_for<0, src_length_k, 1>{}([&](auto iK) {
42 constexpr auto offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK));
43
44 Accumulation::Calculate(dst_buf(Number<out_offset>{}), src_buf[Number<offset>{}]);
45 });
46 });
47 };
48};
49
50// Assumptions:
51// 1) SrcThreadDesc_M_K is known at compile time
52// 2) DstThreadDesc_M is known at compile time
53// 3) the source value/index buffers are static buffers
54// 4) the destination value/index buffers are static buffers
55template <
56 typename AccDataType,
57 typename IndexDataType,
58 typename SrcThreadDesc_M_K,
59 typename DstThreadDesc_M,
60 typename OpReduce,
61 bool PropagateNan,
62 typename Accumulation =
63 detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>>
65{
66 static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{};
67 static constexpr auto dst_thread_desc_m = DstThreadDesc_M{};
68
69 static constexpr auto src_length_m = src_thread_desc_m_k.GetLength(Number<0>{});
70 static constexpr auto src_length_k = src_thread_desc_m_k.GetLength(Number<1>{});
71 static constexpr auto dst_length_m = dst_thread_desc_m.GetLength(Number<0>{});
72
73 static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
74
75 template <typename SrcValueBufferType,
76 typename SrcIndexBufferType,
77 typename DstValueBufferType,
78 typename DstIndexBufferType>
79 __device__ static void Reduce(const SrcValueBufferType& src_val_buf,
80 const SrcIndexBufferType& src_idx_buf,
81 DstValueBufferType& dst_val_buf,
82 DstIndexBufferType& dst_idx_buf)
83 {
84 static_for<0, src_length_m, 1>{}([&](auto iM) {
85 constexpr index_t out_offset = dst_thread_desc_m.CalculateOffset(make_tuple(iM));
86
87 static_for<0, src_length_k, 1>{}([&](auto iK) {
88 constexpr auto offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK));
89
90 Accumulation::Calculate(dst_val_buf(Number<out_offset>{}),
91 src_val_buf[Number<offset>{}],
92 dst_idx_buf(Number<out_offset>{}),
93 src_idx_buf[Number<offset>{}]);
94 });
95 });
96 };
97};
98
99} // namespace ck
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
Definition reduction_functions_threadwise.hpp:23
static constexpr auto dst_thread_desc_m
Definition reduction_functions_threadwise.hpp:25
static __device__ void Reduce(const SrcBufferType &src_buf, DstBufferType &dst_buf)
Definition reduction_functions_threadwise.hpp:36
static constexpr auto src_thread_desc_m_k
Definition reduction_functions_threadwise.hpp:24
Definition reduction_functions_threadwise.hpp:65
static constexpr auto src_thread_desc_m_k
Definition reduction_functions_threadwise.hpp:66
static constexpr auto src_length_m
Definition reduction_functions_threadwise.hpp:69
static __device__ void Reduce(const SrcValueBufferType &src_val_buf, const SrcIndexBufferType &src_idx_buf, DstValueBufferType &dst_val_buf, DstIndexBufferType &dst_idx_buf)
Definition reduction_functions_threadwise.hpp:79
static constexpr auto src_length_k
Definition reduction_functions_threadwise.hpp:70
static constexpr auto dst_length_m
Definition reduction_functions_threadwise.hpp:71
static constexpr auto dst_thread_desc_m
Definition reduction_functions_threadwise.hpp:67
Definition reduction_functions_accumulate.hpp:28
Definition functional2.hpp:33