Skip to content

Commit 7051612

Browse files
committed
[CUB] Refactor DeviceSelect::UniqueByKey to always take an environment
We want to be able to pass tunings even to the APIs that currently only take a stream. Refactor so that we can pass an arbitrary environment to those APIs that take user provided memory
1 parent e1b31be commit 7051612

2 files changed

Lines changed: 284 additions & 53 deletions

File tree

cub/cub/device/device_select.cuh

Lines changed: 63 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1900,10 +1900,8 @@ struct DeviceSelect
19001900
typename NumSelectedIteratorT,
19011901
typename NumItemsT,
19021902
typename EqualityOpT,
1903-
typename EnvT = ::cuda::std::execution::env<>,
1904-
::cuda::std::enable_if_t<::cuda::std::is_integral_v<NumItemsT>&& ::cuda::std::
1905-
indirect_binary_predicate<EqualityOpT, KeyInputIteratorT, KeyInputIteratorT>,
1906-
int> = 0>
1903+
typename EnvT = ::cuda::std::execution::env<>,
1904+
::cuda::std::enable_if_t<::cuda::std::is_integral_v<NumItemsT>, int> = 0>
19071905
[[nodiscard]] CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t UniqueByKey(
19081906
KeyInputIteratorT d_keys_in,
19091907
ValueInputIteratorT d_values_in,
@@ -1912,7 +1910,7 @@ struct DeviceSelect
19121910
NumSelectedIteratorT d_num_selected_out,
19131911
NumItemsT num_items,
19141912
EqualityOpT equality_op,
1915-
EnvT env = {})
1913+
const EnvT& env = {})
19161914
{
19171915
_CCCL_NVTX_RANGE_SCOPE("cub::DeviceSelect::UniqueByKey");
19181916

@@ -2021,18 +2019,18 @@ struct DeviceSelect
20212019
typename ValueOutputIteratorT,
20222020
typename NumSelectedIteratorT,
20232021
typename NumItemsT,
2024-
typename EnvT = ::cuda::std::execution::env<>,
2025-
::cuda::std::enable_if_t<::cuda::std::is_integral_v<NumItemsT>
2026-
&& !::cuda::std::indirect_binary_predicate<EnvT, KeyInputIteratorT, KeyInputIteratorT>,
2027-
int> = 0>
2022+
typename EnvT = ::cuda::std::execution::env<>,
2023+
::cuda::std::enable_if_t<::cuda::std::is_integral_v<NumItemsT>, int> = 0,
2024+
::cuda::std::enable_if_t<!::cuda::std::indirect_binary_predicate<EnvT, KeyInputIteratorT, KeyInputIteratorT>, int> =
2025+
0>
20282026
[[nodiscard]] CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t UniqueByKey(
20292027
KeyInputIteratorT d_keys_in,
20302028
ValueInputIteratorT d_values_in,
20312029
KeyOutputIteratorT d_keys_out,
20322030
ValueOutputIteratorT d_values_out,
20332031
NumSelectedIteratorT d_num_selected_out,
20342032
NumItemsT num_items,
2035-
EnvT env = {})
2033+
const EnvT& env = {})
20362034
{
20372035
return UniqueByKey(
20382036
d_keys_in, d_values_in, d_keys_out, d_values_out, d_num_selected_out, num_items, ::cuda::std::equal_to<>{}, env);
@@ -2562,6 +2560,9 @@ struct DeviceSelect
25622560
//! @tparam EqualityOpT
25632561
//! **[inferred]** Type of equality_op
25642562
//!
2563+
//! @tparam EnvT
2564+
//! **[inferred]** Environment type (e.g., `cuda::std::execution::env<...>`)
2565+
//!
25652566
//! @param[in] d_temp_storage
25662567
//! @devicestorage
25672568
//!
@@ -2589,48 +2590,51 @@ struct DeviceSelect
25892590
//! @param[in] equality_op
25902591
//! Binary predicate to determine equality
25912592
//!
2592-
//! @param[in] stream
2593-
//! @rst
2594-
//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`.
2595-
//! @endrst
2593+
//! @param[in] env
2594+
//! **[optional]** Execution environment. Default is ``cuda::std::execution::env{}``.
25962595
template <typename KeyInputIteratorT,
25972596
typename ValueInputIteratorT,
25982597
typename KeyOutputIteratorT,
25992598
typename ValueOutputIteratorT,
26002599
typename NumSelectedIteratorT,
26012600
typename NumItemsT,
2602-
typename EqualityOpT>
2603-
CUB_RUNTIME_FUNCTION __forceinline__ static //
2604-
::cuda::std::enable_if_t< //
2605-
!::cuda::std::is_convertible_v<EqualityOpT, cudaStream_t>, //
2606-
cudaError_t>
2607-
UniqueByKey(
2608-
void* d_temp_storage,
2609-
size_t& temp_storage_bytes,
2610-
KeyInputIteratorT d_keys_in,
2611-
ValueInputIteratorT d_values_in,
2612-
KeyOutputIteratorT d_keys_out,
2613-
ValueOutputIteratorT d_values_out,
2614-
NumSelectedIteratorT d_num_selected_out,
2615-
NumItemsT num_items,
2616-
EqualityOpT equality_op,
2617-
cudaStream_t stream = nullptr)
2601+
typename EqualityOpT,
2602+
typename EnvT = ::cuda::std::execution::env<>,
2603+
::cuda::std::enable_if_t<::cuda::std::is_integral_v<NumItemsT>, int> = 0>
2604+
CUB_RUNTIME_FUNCTION __forceinline__ static cudaError_t UniqueByKey(
2605+
void* d_temp_storage,
2606+
size_t& temp_storage_bytes,
2607+
KeyInputIteratorT d_keys_in,
2608+
ValueInputIteratorT d_values_in,
2609+
KeyOutputIteratorT d_keys_out,
2610+
ValueOutputIteratorT d_values_out,
2611+
NumSelectedIteratorT d_num_selected_out,
2612+
NumItemsT num_items,
2613+
EqualityOpT equality_op,
2614+
const EnvT& env = {})
26182615
{
26192616
_CCCL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceSelect::UniqueByKey");
26202617

26212618
using offset_t = detail::choose_offset_t<NumItemsT>;
26222619

2623-
return detail::unique_by_key::dispatch(
2624-
d_temp_storage,
2625-
temp_storage_bytes,
2626-
d_keys_in,
2627-
d_values_in,
2628-
d_keys_out,
2629-
d_values_out,
2630-
d_num_selected_out,
2631-
equality_op,
2632-
static_cast<offset_t>(num_items),
2633-
stream);
2620+
using default_policy_selector =
2621+
detail::unique_by_key::policy_selector_from_types<detail::it_value_t<KeyInputIteratorT>,
2622+
detail::it_value_t<ValueInputIteratorT>>;
2623+
return detail::dispatch_with_env_and_tuning<default_policy_selector>(
2624+
d_temp_storage, temp_storage_bytes, env, [&](auto policy_selector, void* storage, size_t& bytes, auto stream) {
2625+
return detail::unique_by_key::dispatch(
2626+
storage,
2627+
bytes,
2628+
d_keys_in,
2629+
d_values_in,
2630+
d_keys_out,
2631+
d_values_out,
2632+
d_num_selected_out,
2633+
equality_op,
2634+
static_cast<offset_t>(num_items),
2635+
stream,
2636+
policy_selector);
2637+
});
26342638
}
26352639

26362640
//! @rst
@@ -2716,6 +2720,9 @@ struct DeviceSelect
27162720
//! @tparam NumItemsT
27172721
//! **[inferred]** Type of num_items
27182722
//!
2723+
//! @tparam EnvT
2724+
//! **[inferred]** Environment type (e.g., `cuda::std::execution::env<...>`)
2725+
//!
27192726
//! @param[in] d_temp_storage
27202727
//! @devicestorage
27212728
//!
@@ -2740,16 +2747,19 @@ struct DeviceSelect
27402747
//! @param[in] num_items
27412748
//! Total number of input items (i.e., length of `d_keys_in` or `d_values_in`)
27422749
//!
2743-
//! @param[in] stream
2744-
//! @rst
2745-
//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`.
2746-
//! @endrst
2747-
template <typename KeyInputIteratorT,
2748-
typename ValueInputIteratorT,
2749-
typename KeyOutputIteratorT,
2750-
typename ValueOutputIteratorT,
2751-
typename NumSelectedIteratorT,
2752-
typename NumItemsT>
2750+
//! @param[in] env
2751+
//! **[optional]** Execution environment. Default is ``cuda::std::execution::env{}``.
2752+
template <
2753+
typename KeyInputIteratorT,
2754+
typename ValueInputIteratorT,
2755+
typename KeyOutputIteratorT,
2756+
typename ValueOutputIteratorT,
2757+
typename NumSelectedIteratorT,
2758+
typename NumItemsT,
2759+
typename EnvT = ::cuda::std::execution::env<>,
2760+
::cuda::std::enable_if_t<::cuda::std::is_integral_v<NumItemsT>, int> = 0,
2761+
::cuda::std::enable_if_t<!::cuda::std::indirect_binary_predicate<EnvT, KeyInputIteratorT, KeyInputIteratorT>, int> =
2762+
0>
27532763
CUB_RUNTIME_FUNCTION __forceinline__ static cudaError_t UniqueByKey(
27542764
void* d_temp_storage,
27552765
size_t& temp_storage_bytes,
@@ -2759,7 +2769,7 @@ struct DeviceSelect
27592769
ValueOutputIteratorT d_values_out,
27602770
NumSelectedIteratorT d_num_selected_out,
27612771
NumItemsT num_items,
2762-
cudaStream_t stream = nullptr)
2772+
const EnvT& env = {})
27632773
{
27642774
return UniqueByKey(
27652775
d_temp_storage,
@@ -2771,7 +2781,7 @@ struct DeviceSelect
27712781
d_num_selected_out,
27722782
num_items,
27732783
::cuda::std::equal_to<>{},
2774-
stream);
2784+
env);
27752785
}
27762786
};
27772787

0 commit comments

Comments
 (0)