@@ -1900,10 +1900,8 @@ struct DeviceSelect
19001900 typename NumSelectedIteratorT,
19011901 typename NumItemsT,
19021902 typename EqualityOpT,
1903- typename EnvT = ::cuda::std::execution::env<>,
1904- ::cuda::std::enable_if_t <::cuda::std::is_integral_v<NumItemsT>&& ::cuda::std::
1905- indirect_binary_predicate<EqualityOpT, KeyInputIteratorT, KeyInputIteratorT>,
1906- int > = 0 >
1903+ typename EnvT = ::cuda::std::execution::env<>,
1904+ ::cuda::std::enable_if_t <::cuda::std::is_integral_v<NumItemsT>, int > = 0 >
19071905 [[nodiscard]] CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t UniqueByKey (
19081906 KeyInputIteratorT d_keys_in,
19091907 ValueInputIteratorT d_values_in,
@@ -1912,7 +1910,7 @@ struct DeviceSelect
19121910 NumSelectedIteratorT d_num_selected_out,
19131911 NumItemsT num_items,
19141912 EqualityOpT equality_op,
1915- EnvT env = {})
1913+ const EnvT& env = {})
19161914 {
19171915 _CCCL_NVTX_RANGE_SCOPE (" cub::DeviceSelect::UniqueByKey" );
19181916
@@ -2021,18 +2019,18 @@ struct DeviceSelect
20212019 typename ValueOutputIteratorT,
20222020 typename NumSelectedIteratorT,
20232021 typename NumItemsT,
2024- typename EnvT = ::cuda::std::execution::env<>,
2025- ::cuda::std::enable_if_t <::cuda::std::is_integral_v<NumItemsT>
2026- && !::cuda::std::indirect_binary_predicate<EnvT, KeyInputIteratorT, KeyInputIteratorT>,
2027- int > = 0 >
2022+ typename EnvT = ::cuda::std::execution::env<>,
2023+ ::cuda::std::enable_if_t <::cuda::std::is_integral_v<NumItemsT>, int > = 0 ,
2024+ ::cuda::std:: enable_if_t < !::cuda::std::indirect_binary_predicate<EnvT, KeyInputIteratorT, KeyInputIteratorT>, int > =
2025+ 0 >
20282026 [[nodiscard]] CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t UniqueByKey (
20292027 KeyInputIteratorT d_keys_in,
20302028 ValueInputIteratorT d_values_in,
20312029 KeyOutputIteratorT d_keys_out,
20322030 ValueOutputIteratorT d_values_out,
20332031 NumSelectedIteratorT d_num_selected_out,
20342032 NumItemsT num_items,
2035- EnvT env = {})
2033+ const EnvT& env = {})
20362034 {
20372035 return UniqueByKey (
20382036 d_keys_in, d_values_in, d_keys_out, d_values_out, d_num_selected_out, num_items, ::cuda::std::equal_to<>{}, env);
@@ -2562,6 +2560,9 @@ struct DeviceSelect
25622560 // ! @tparam EqualityOpT
25632561 // ! **[inferred]** Type of equality_op
25642562 // !
2563+ // ! @tparam EnvT
2564+ // ! **[inferred]** Environment type (e.g., `cuda::std::execution::env<...>`)
2565+ // !
25652566 // ! @param[in] d_temp_storage
25662567 // ! @devicestorage
25672568 // !
@@ -2589,48 +2590,51 @@ struct DeviceSelect
25892590 // ! @param[in] equality_op
25902591 // ! Binary predicate to determine equality
25912592 // !
2592- // ! @param[in] stream
2593- // ! @rst
2594- // ! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`.
2595- // ! @endrst
2593+ // ! @param[in] env
2594+ // ! **[optional]** Execution environment. Default is ``cuda::std::execution::env{}``.
// Review note: this hunk replaces the stream-taking UniqueByKey overload (the `-` lines) with an
// environment-taking one (the `+` lines). The template parameters, the temp-storage arguments, the
// iterators, num_items and equality_op are carried over; the trailing `cudaStream_t stream`
// parameter becomes `const EnvT& env`.
25962595 template <typename KeyInputIteratorT,
25972596 typename ValueInputIteratorT,
25982597 typename KeyOutputIteratorT,
25992598 typename ValueOutputIteratorT,
26002599 typename NumSelectedIteratorT,
26012600 typename NumItemsT,
// Removed: the return type doubled as a SFINAE guard that dropped this overload whenever
// EqualityOpT was convertible to cudaStream_t.
2602- typename EqualityOpT>
2603- CUB_RUNTIME_FUNCTION __forceinline__ static //
2604- ::cuda::std::enable_if_t < //
2605- !::cuda::std::is_convertible_v<EqualityOpT, cudaStream_t>, //
2606- cudaError_t>
2607- UniqueByKey (
2608- void * d_temp_storage,
2609- size_t & temp_storage_bytes,
2610- KeyInputIteratorT d_keys_in,
2611- ValueInputIteratorT d_values_in,
2612- KeyOutputIteratorT d_keys_out,
2613- ValueOutputIteratorT d_values_out,
2614- NumSelectedIteratorT d_num_selected_out,
2615- NumItemsT num_items,
2616- EqualityOpT equality_op,
2617- cudaStream_t stream = nullptr )
// Added: EnvT defaults to ::cuda::std::execution::env<>, and the overload is enabled only when
// NumItemsT is an integral type. The return type is now a plain cudaError_t.
// NOTE(review): EqualityOpT is no longer constrained in this signature. Confirm that a call which
// passes an environment as the last argument (i.e. in the equality_op position, with env left
// defaulted) cannot match this overload as well as the one that takes no equality_op — TODO verify.
2601+ typename EqualityOpT,
2602+ typename EnvT = ::cuda::std::execution::env<>,
2603+ ::cuda::std::enable_if_t <::cuda::std::is_integral_v<NumItemsT>, int > = 0 >
2604+ CUB_RUNTIME_FUNCTION __forceinline__ static cudaError_t UniqueByKey (
2605+ void * d_temp_storage,
2606+ size_t & temp_storage_bytes,
2607+ KeyInputIteratorT d_keys_in,
2608+ ValueInputIteratorT d_values_in,
2609+ KeyOutputIteratorT d_keys_out,
2610+ ValueOutputIteratorT d_values_out,
2611+ NumSelectedIteratorT d_num_selected_out,
2612+ NumItemsT num_items,
2613+ EqualityOpT equality_op,
2614+ const EnvT& env = {})
26182615 {
// The NVTX range is conditional on d_temp_storage; presumably it is skipped when d_temp_storage is
// null — confirm against the macro definition.
26192616 _CCCL_NVTX_RANGE_SCOPE_IF (d_temp_storage, " cub::DeviceSelect::UniqueByKey" );
26202617
// num_items is cast to this offset type before it is handed to dispatch.
26212618 using offset_t = detail::choose_offset_t <NumItemsT>;
26222619
// Removed: direct call to detail::unique_by_key::dispatch on the caller-supplied stream.
2623- return detail::unique_by_key::dispatch (
2624- d_temp_storage,
2625- temp_storage_bytes,
2626- d_keys_in,
2627- d_values_in,
2628- d_keys_out,
2629- d_values_out,
2630- d_num_selected_out,
2631- equality_op,
2632- static_cast <offset_t >(num_items),
2633- stream);
// Added: the default policy selector is derived from the value types of the key and value input
// iterators.
2620+ using default_policy_selector =
2621+ detail::unique_by_key::policy_selector_from_types<detail::it_value_t <KeyInputIteratorT>,
2622+ detail::it_value_t <ValueInputIteratorT>>;
// Added: detail::dispatch_with_env_and_tuning receives the temp-storage arguments and env, plus a
// lambda that forwards the policy_selector, storage, bytes and stream it is called with to
// detail::unique_by_key::dispatch. The lambda captures the iterators, equality_op and num_items by
// reference; presumably it is only invoked before this function returns — TODO confirm.
2623+ return detail::dispatch_with_env_and_tuning<default_policy_selector>(
2624+ d_temp_storage, temp_storage_bytes, env, [&](auto policy_selector, void * storage, size_t & bytes, auto stream) {
2625+ return detail::unique_by_key::dispatch (
2626+ storage,
2627+ bytes,
2628+ d_keys_in,
2629+ d_values_in,
2630+ d_keys_out,
2631+ d_values_out,
2632+ d_num_selected_out,
2633+ equality_op,
2634+ static_cast <offset_t >(num_items),
2635+ stream,
2636+ policy_selector);
2637+ });
26342638 }
26352639
26362640 // ! @rst
@@ -2716,6 +2720,9 @@ struct DeviceSelect
27162720 // ! @tparam NumItemsT
27172721 // ! **[inferred]** Type of num_items
27182722 // !
2723+ // ! @tparam EnvT
2724+ // ! **[inferred]** Environment type (e.g., `cuda::std::execution::env<...>`)
2725+ // !
27192726 // ! @param[in] d_temp_storage
27202727 // ! @devicestorage
27212728 // !
@@ -2740,16 +2747,19 @@ struct DeviceSelect
27402747 // ! @param[in] num_items
27412748 // ! Total number of input items (i.e., length of `d_keys_in` or `d_values_in`)
27422749 // !
2743- // ! @param[in] stream
2744- // ! @rst
2745- // ! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`.
2746- // ! @endrst
2747- template <typename KeyInputIteratorT,
2748- typename ValueInputIteratorT,
2749- typename KeyOutputIteratorT,
2750- typename ValueOutputIteratorT,
2751- typename NumSelectedIteratorT,
2752- typename NumItemsT>
2750+ // ! @param[in] env
2751+ // ! **[optional]** Execution environment. Default is ``cuda::std::execution::env{}``.
2752+ template <
2753+ typename KeyInputIteratorT,
2754+ typename ValueInputIteratorT,
2755+ typename KeyOutputIteratorT,
2756+ typename ValueOutputIteratorT,
2757+ typename NumSelectedIteratorT,
2758+ typename NumItemsT,
2759+ typename EnvT = ::cuda::std::execution::env<>,
2760+ ::cuda::std::enable_if_t <::cuda::std::is_integral_v<NumItemsT>, int > = 0 ,
2761+ ::cuda::std::enable_if_t <!::cuda::std::indirect_binary_predicate<EnvT, KeyInputIteratorT, KeyInputIteratorT>, int > =
2762+ 0 >
27532763 CUB_RUNTIME_FUNCTION __forceinline__ static cudaError_t UniqueByKey (
27542764 void * d_temp_storage,
27552765 size_t & temp_storage_bytes,
@@ -2759,7 +2769,7 @@ struct DeviceSelect
27592769 ValueOutputIteratorT d_values_out,
27602770 NumSelectedIteratorT d_num_selected_out,
27612771 NumItemsT num_items,
2762- cudaStream_t stream = nullptr )
2772+ const EnvT& env = {} )
27632773 {
27642774 return UniqueByKey (
27652775 d_temp_storage,
@@ -2771,7 +2781,7 @@ struct DeviceSelect
27712781 d_num_selected_out,
27722782 num_items,
27732783 ::cuda::std::equal_to<>{},
2774- stream );
2784+ env );
27752785 }
27762786};
27772787
0 commit comments