Skip to content

Commit c9e9e91

Browse files
committed
[CUB] Refactor DeviceSelect::UniqueByKey to always take an environment
We want to be able to pass tunings even to the APIs that currently only take a stream. Refactor so that we can pass an arbitrary environment to those APIs that take user-provided memory.
1 parent e1b31be commit c9e9e91

2 files changed

Lines changed: 303 additions & 66 deletions

File tree

cub/cub/device/device_select.cuh

Lines changed: 82 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1893,17 +1893,18 @@ struct DeviceSelect
18931893
//!
18941894
//! @param[in] env
18951895
//! **[optional]** Execution environment. Default is ``cuda::std::execution::env{}``.
1896-
template <typename KeyInputIteratorT,
1897-
typename ValueInputIteratorT,
1898-
typename KeyOutputIteratorT,
1899-
typename ValueOutputIteratorT,
1900-
typename NumSelectedIteratorT,
1901-
typename NumItemsT,
1902-
typename EqualityOpT,
1903-
typename EnvT = ::cuda::std::execution::env<>,
1904-
::cuda::std::enable_if_t<::cuda::std::is_integral_v<NumItemsT>&& ::cuda::std::
1905-
indirect_binary_predicate<EqualityOpT, KeyInputIteratorT, KeyInputIteratorT>,
1906-
int> = 0>
1896+
template <
1897+
typename KeyInputIteratorT,
1898+
typename ValueInputIteratorT,
1899+
typename KeyOutputIteratorT,
1900+
typename ValueOutputIteratorT,
1901+
typename NumSelectedIteratorT,
1902+
typename NumItemsT,
1903+
typename EqualityOpT,
1904+
typename EnvT = ::cuda::std::execution::env<>,
1905+
::cuda::std::enable_if_t<::cuda::std::is_integral_v<NumItemsT>, int> = 0,
1906+
::cuda::std::enable_if_t<::cuda::std::indirect_binary_predicate<EqualityOpT, KeyInputIteratorT, KeyInputIteratorT>,
1907+
int> = 0>
19071908
[[nodiscard]] CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t UniqueByKey(
19081909
KeyInputIteratorT d_keys_in,
19091910
ValueInputIteratorT d_values_in,
@@ -1912,7 +1913,7 @@ struct DeviceSelect
19121913
NumSelectedIteratorT d_num_selected_out,
19131914
NumItemsT num_items,
19141915
EqualityOpT equality_op,
1915-
EnvT env = {})
1916+
const EnvT& env = {})
19161917
{
19171918
_CCCL_NVTX_RANGE_SCOPE("cub::DeviceSelect::UniqueByKey");
19181919

@@ -2021,18 +2022,18 @@ struct DeviceSelect
20212022
typename ValueOutputIteratorT,
20222023
typename NumSelectedIteratorT,
20232024
typename NumItemsT,
2024-
typename EnvT = ::cuda::std::execution::env<>,
2025-
::cuda::std::enable_if_t<::cuda::std::is_integral_v<NumItemsT>
2026-
&& !::cuda::std::indirect_binary_predicate<EnvT, KeyInputIteratorT, KeyInputIteratorT>,
2027-
int> = 0>
2025+
typename EnvT = ::cuda::std::execution::env<>,
2026+
::cuda::std::enable_if_t<::cuda::std::is_integral_v<NumItemsT>, int> = 0,
2027+
::cuda::std::enable_if_t<!::cuda::std::indirect_binary_predicate<EnvT, KeyInputIteratorT, KeyInputIteratorT>, int> =
2028+
0>
20282029
[[nodiscard]] CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t UniqueByKey(
20292030
KeyInputIteratorT d_keys_in,
20302031
ValueInputIteratorT d_values_in,
20312032
KeyOutputIteratorT d_keys_out,
20322033
ValueOutputIteratorT d_values_out,
20332034
NumSelectedIteratorT d_num_selected_out,
20342035
NumItemsT num_items,
2035-
EnvT env = {})
2036+
const EnvT& env = {})
20362037
{
20372038
return UniqueByKey(
20382039
d_keys_in, d_values_in, d_keys_out, d_values_out, d_num_selected_out, num_items, ::cuda::std::equal_to<>{}, env);
@@ -2562,6 +2563,9 @@ struct DeviceSelect
25622563
//! @tparam EqualityOpT
25632564
//! **[inferred]** Type of equality_op
25642565
//!
2566+
//! @tparam EnvT
2567+
//! **[inferred]** Environment type (e.g., `cuda::std::execution::env<...>`)
2568+
//!
25652569
//! @param[in] d_temp_storage
25662570
//! @devicestorage
25672571
//!
@@ -2589,48 +2593,54 @@ struct DeviceSelect
25892593
//! @param[in] equality_op
25902594
//! Binary predicate to determine equality
25912595
//!
2592-
//! @param[in] stream
2593-
//! @rst
2594-
//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`.
2595-
//! @endrst
2596-
template <typename KeyInputIteratorT,
2597-
typename ValueInputIteratorT,
2598-
typename KeyOutputIteratorT,
2599-
typename ValueOutputIteratorT,
2600-
typename NumSelectedIteratorT,
2601-
typename NumItemsT,
2602-
typename EqualityOpT>
2603-
CUB_RUNTIME_FUNCTION __forceinline__ static //
2604-
::cuda::std::enable_if_t< //
2605-
!::cuda::std::is_convertible_v<EqualityOpT, cudaStream_t>, //
2606-
cudaError_t>
2607-
UniqueByKey(
2608-
void* d_temp_storage,
2609-
size_t& temp_storage_bytes,
2610-
KeyInputIteratorT d_keys_in,
2611-
ValueInputIteratorT d_values_in,
2612-
KeyOutputIteratorT d_keys_out,
2613-
ValueOutputIteratorT d_values_out,
2614-
NumSelectedIteratorT d_num_selected_out,
2615-
NumItemsT num_items,
2616-
EqualityOpT equality_op,
2617-
cudaStream_t stream = nullptr)
2596+
//! @param[in] env
2597+
//! **[optional]** Execution environment. Default is ``cuda::std::execution::env{}``.
2598+
template <
2599+
typename KeyInputIteratorT,
2600+
typename ValueInputIteratorT,
2601+
typename KeyOutputIteratorT,
2602+
typename ValueOutputIteratorT,
2603+
typename NumSelectedIteratorT,
2604+
typename NumItemsT,
2605+
typename EqualityOpT,
2606+
typename EnvT = ::cuda::std::execution::env<>,
2607+
::cuda::std::enable_if_t<::cuda::std::is_integral_v<NumItemsT>, int> = 0,
2608+
::cuda::std::enable_if_t<::cuda::std::indirect_binary_predicate<EqualityOpT, KeyInputIteratorT, KeyInputIteratorT>,
2609+
int> = 0>
2610+
CUB_RUNTIME_FUNCTION __forceinline__ static cudaError_t UniqueByKey(
2611+
void* d_temp_storage,
2612+
size_t& temp_storage_bytes,
2613+
KeyInputIteratorT d_keys_in,
2614+
ValueInputIteratorT d_values_in,
2615+
KeyOutputIteratorT d_keys_out,
2616+
ValueOutputIteratorT d_values_out,
2617+
NumSelectedIteratorT d_num_selected_out,
2618+
NumItemsT num_items,
2619+
EqualityOpT equality_op,
2620+
const EnvT& env = {})
26182621
{
26192622
_CCCL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceSelect::UniqueByKey");
26202623

26212624
using offset_t = detail::choose_offset_t<NumItemsT>;
26222625

2623-
return detail::unique_by_key::dispatch(
2624-
d_temp_storage,
2625-
temp_storage_bytes,
2626-
d_keys_in,
2627-
d_values_in,
2628-
d_keys_out,
2629-
d_values_out,
2630-
d_num_selected_out,
2631-
equality_op,
2632-
static_cast<offset_t>(num_items),
2633-
stream);
2626+
using default_policy_selector =
2627+
detail::unique_by_key::policy_selector_from_types<detail::it_value_t<KeyInputIteratorT>,
2628+
detail::it_value_t<ValueInputIteratorT>>;
2629+
return detail::dispatch_with_env_and_tuning<default_policy_selector>(
2630+
d_temp_storage, temp_storage_bytes, env, [&](auto policy_selector, void* storage, size_t& bytes, auto stream) {
2631+
return detail::unique_by_key::dispatch(
2632+
storage,
2633+
bytes,
2634+
d_keys_in,
2635+
d_values_in,
2636+
d_keys_out,
2637+
d_values_out,
2638+
d_num_selected_out,
2639+
equality_op,
2640+
static_cast<offset_t>(num_items),
2641+
stream,
2642+
policy_selector);
2643+
});
26342644
}
26352645

26362646
//! @rst
@@ -2716,6 +2726,9 @@ struct DeviceSelect
27162726
//! @tparam NumItemsT
27172727
//! **[inferred]** Type of num_items
27182728
//!
2729+
//! @tparam EnvT
2730+
//! **[inferred]** Environment type (e.g., `cuda::std::execution::env<...>`)
2731+
//!
27192732
//! @param[in] d_temp_storage
27202733
//! @devicestorage
27212734
//!
@@ -2740,16 +2753,19 @@ struct DeviceSelect
27402753
//! @param[in] num_items
27412754
//! Total number of input items (i.e., length of `d_keys_in` or `d_values_in`)
27422755
//!
2743-
//! @param[in] stream
2744-
//! @rst
2745-
//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`.
2746-
//! @endrst
2747-
template <typename KeyInputIteratorT,
2748-
typename ValueInputIteratorT,
2749-
typename KeyOutputIteratorT,
2750-
typename ValueOutputIteratorT,
2751-
typename NumSelectedIteratorT,
2752-
typename NumItemsT>
2756+
//! @param[in] env
2757+
//! **[optional]** Execution environment. Default is ``cuda::std::execution::env{}``.
2758+
template <
2759+
typename KeyInputIteratorT,
2760+
typename ValueInputIteratorT,
2761+
typename KeyOutputIteratorT,
2762+
typename ValueOutputIteratorT,
2763+
typename NumSelectedIteratorT,
2764+
typename NumItemsT,
2765+
typename EnvT = ::cuda::std::execution::env<>,
2766+
::cuda::std::enable_if_t<::cuda::std::is_integral_v<NumItemsT>, int> = 0,
2767+
::cuda::std::enable_if_t<!::cuda::std::indirect_binary_predicate<EnvT, KeyInputIteratorT, KeyInputIteratorT>, int> =
2768+
0>
27532769
CUB_RUNTIME_FUNCTION __forceinline__ static cudaError_t UniqueByKey(
27542770
void* d_temp_storage,
27552771
size_t& temp_storage_bytes,
@@ -2759,7 +2775,7 @@ struct DeviceSelect
27592775
ValueOutputIteratorT d_values_out,
27602776
NumSelectedIteratorT d_num_selected_out,
27612777
NumItemsT num_items,
2762-
cudaStream_t stream = nullptr)
2778+
const EnvT& env = {})
27632779
{
27642780
return UniqueByKey(
27652781
d_temp_storage,
@@ -2771,7 +2787,7 @@ struct DeviceSelect
27712787
d_num_selected_out,
27722788
num_items,
27732789
::cuda::std::equal_to<>{},
2774-
stream);
2790+
env);
27752791
}
27762792
};
27772793

0 commit comments

Comments
 (0)