@@ -1893,17 +1893,18 @@ struct DeviceSelect
18931893 // !
18941894 // ! @param[in] env
18951895 // ! **[optional]** Execution environment. Default is ``cuda::std::execution::env{}``.
1896- template <typename KeyInputIteratorT,
1897- typename ValueInputIteratorT,
1898- typename KeyOutputIteratorT,
1899- typename ValueOutputIteratorT,
1900- typename NumSelectedIteratorT,
1901- typename NumItemsT,
1902- typename EqualityOpT,
1903- typename EnvT = ::cuda::std::execution::env<>,
1904- ::cuda::std::enable_if_t <::cuda::std::is_integral_v<NumItemsT>&& ::cuda::std::
1905- indirect_binary_predicate<EqualityOpT, KeyInputIteratorT, KeyInputIteratorT>,
1906- int > = 0 >
1896+ template <
1897+ typename KeyInputIteratorT,
1898+ typename ValueInputIteratorT,
1899+ typename KeyOutputIteratorT,
1900+ typename ValueOutputIteratorT,
1901+ typename NumSelectedIteratorT,
1902+ typename NumItemsT,
1903+ typename EqualityOpT,
1904+ typename EnvT = ::cuda::std::execution::env<>,
1905+ ::cuda::std::enable_if_t <::cuda::std::is_integral_v<NumItemsT>, int > = 0 ,
1906+ ::cuda::std::enable_if_t <::cuda::std::indirect_binary_predicate<EqualityOpT, KeyInputIteratorT, KeyInputIteratorT>,
1907+ int > = 0 >
19071908 [[nodiscard]] CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t UniqueByKey (
19081909 KeyInputIteratorT d_keys_in,
19091910 ValueInputIteratorT d_values_in,
@@ -1912,7 +1913,7 @@ struct DeviceSelect
19121913 NumSelectedIteratorT d_num_selected_out,
19131914 NumItemsT num_items,
19141915 EqualityOpT equality_op,
1915- EnvT env = {})
1916+ const EnvT& env = {})
19161917 {
19171918 _CCCL_NVTX_RANGE_SCOPE (" cub::DeviceSelect::UniqueByKey" );
19181919
@@ -2021,18 +2022,18 @@ struct DeviceSelect
20212022 typename ValueOutputIteratorT,
20222023 typename NumSelectedIteratorT,
20232024 typename NumItemsT,
2024- typename EnvT = ::cuda::std::execution::env<>,
2025- ::cuda::std::enable_if_t <::cuda::std::is_integral_v<NumItemsT>
2026- && !::cuda::std::indirect_binary_predicate<EnvT, KeyInputIteratorT, KeyInputIteratorT>,
2027- int > = 0 >
2025+ typename EnvT = ::cuda::std::execution::env<>,
2026+ ::cuda::std::enable_if_t <::cuda::std::is_integral_v<NumItemsT>, int > = 0 ,
2027+ ::cuda::std:: enable_if_t < !::cuda::std::indirect_binary_predicate<EnvT, KeyInputIteratorT, KeyInputIteratorT>, int > =
2028+ 0 >
20282029 [[nodiscard]] CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t UniqueByKey (
20292030 KeyInputIteratorT d_keys_in,
20302031 ValueInputIteratorT d_values_in,
20312032 KeyOutputIteratorT d_keys_out,
20322033 ValueOutputIteratorT d_values_out,
20332034 NumSelectedIteratorT d_num_selected_out,
20342035 NumItemsT num_items,
2035- EnvT env = {})
2036+ const EnvT& env = {})
20362037 {
20372038 return UniqueByKey (
20382039 d_keys_in, d_values_in, d_keys_out, d_values_out, d_num_selected_out, num_items, ::cuda::std::equal_to<>{}, env);
@@ -2562,6 +2563,9 @@ struct DeviceSelect
25622563 // ! @tparam EqualityOpT
25632564 // ! **[inferred]** Type of equality_op
25642565 // !
2566+ // ! @tparam EnvT
2567+ // ! **[inferred]** Environment type (e.g., `cuda::std::execution::env<...>`)
2568+ // !
25652569 // ! @param[in] d_temp_storage
25662570 // ! @devicestorage
25672571 // !
@@ -2589,48 +2593,54 @@ struct DeviceSelect
25892593 // ! @param[in] equality_op
25902594 // ! Binary predicate to determine equality
25912595 // !
2592- // ! @param[in] stream
2593- // ! @rst
2594- // ! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`.
2595- // ! @endrst
2596- template <typename KeyInputIteratorT,
2597- typename ValueInputIteratorT,
2598- typename KeyOutputIteratorT,
2599- typename ValueOutputIteratorT,
2600- typename NumSelectedIteratorT,
2601- typename NumItemsT,
2602- typename EqualityOpT>
2603- CUB_RUNTIME_FUNCTION __forceinline__ static //
2604- ::cuda::std::enable_if_t < //
2605- !::cuda::std::is_convertible_v<EqualityOpT, cudaStream_t>, //
2606- cudaError_t>
2607- UniqueByKey (
2608- void * d_temp_storage,
2609- size_t & temp_storage_bytes,
2610- KeyInputIteratorT d_keys_in,
2611- ValueInputIteratorT d_values_in,
2612- KeyOutputIteratorT d_keys_out,
2613- ValueOutputIteratorT d_values_out,
2614- NumSelectedIteratorT d_num_selected_out,
2615- NumItemsT num_items,
2616- EqualityOpT equality_op,
2617- cudaStream_t stream = nullptr )
2596+ // ! @param[in] env
2597+ // ! **[optional]** Execution environment. Default is ``cuda::std::execution::env{}``.
2598+ template <
2599+ typename KeyInputIteratorT,
2600+ typename ValueInputIteratorT,
2601+ typename KeyOutputIteratorT,
2602+ typename ValueOutputIteratorT,
2603+ typename NumSelectedIteratorT,
2604+ typename NumItemsT,
2605+ typename EqualityOpT,
2606+ typename EnvT = ::cuda::std::execution::env<>,
2607+ ::cuda::std::enable_if_t <::cuda::std::is_integral_v<NumItemsT>, int > = 0 ,
2608+ ::cuda::std::enable_if_t <::cuda::std::indirect_binary_predicate<EqualityOpT, KeyInputIteratorT, KeyInputIteratorT>,
2609+ int > = 0 >
2610+ CUB_RUNTIME_FUNCTION __forceinline__ static cudaError_t UniqueByKey (
2611+ void * d_temp_storage,
2612+ size_t & temp_storage_bytes,
2613+ KeyInputIteratorT d_keys_in,
2614+ ValueInputIteratorT d_values_in,
2615+ KeyOutputIteratorT d_keys_out,
2616+ ValueOutputIteratorT d_values_out,
2617+ NumSelectedIteratorT d_num_selected_out,
2618+ NumItemsT num_items,
2619+ EqualityOpT equality_op,
2620+ const EnvT& env = {})
26182621 {
26192622 _CCCL_NVTX_RANGE_SCOPE_IF (d_temp_storage, " cub::DeviceSelect::UniqueByKey" );
26202623
26212624 using offset_t = detail::choose_offset_t <NumItemsT>;
26222625
2623- return detail::unique_by_key::dispatch (
2624- d_temp_storage,
2625- temp_storage_bytes,
2626- d_keys_in,
2627- d_values_in,
2628- d_keys_out,
2629- d_values_out,
2630- d_num_selected_out,
2631- equality_op,
2632- static_cast <offset_t >(num_items),
2633- stream);
2626+ using default_policy_selector =
2627+ detail::unique_by_key::policy_selector_from_types<detail::it_value_t <KeyInputIteratorT>,
2628+ detail::it_value_t <ValueInputIteratorT>>;
2629+ return detail::dispatch_with_env_and_tuning<default_policy_selector>(
2630+ d_temp_storage, temp_storage_bytes, env, [&](auto policy_selector, void * storage, size_t & bytes, auto stream) {
2631+ return detail::unique_by_key::dispatch (
2632+ storage,
2633+ bytes,
2634+ d_keys_in,
2635+ d_values_in,
2636+ d_keys_out,
2637+ d_values_out,
2638+ d_num_selected_out,
2639+ equality_op,
2640+ static_cast <offset_t >(num_items),
2641+ stream,
2642+ policy_selector);
2643+ });
26342644 }
26352645
26362646 // ! @rst
@@ -2716,6 +2726,9 @@ struct DeviceSelect
27162726 // ! @tparam NumItemsT
27172727 // ! **[inferred]** Type of num_items
27182728 // !
2729+ // ! @tparam EnvT
2730+ // ! **[inferred]** Environment type (e.g., `cuda::std::execution::env<...>`)
2731+ // !
27192732 // ! @param[in] d_temp_storage
27202733 // ! @devicestorage
27212734 // !
@@ -2740,16 +2753,19 @@ struct DeviceSelect
27402753 // ! @param[in] num_items
27412754 // ! Total number of input items (i.e., length of `d_keys_in` or `d_values_in`)
27422755 // !
2743- // ! @param[in] stream
2744- // ! @rst
2745- // ! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`.
2746- // ! @endrst
2747- template <typename KeyInputIteratorT,
2748- typename ValueInputIteratorT,
2749- typename KeyOutputIteratorT,
2750- typename ValueOutputIteratorT,
2751- typename NumSelectedIteratorT,
2752- typename NumItemsT>
2756+ // ! @param[in] env
2757+ // ! **[optional]** Execution environment. Default is ``cuda::std::execution::env{}``.
2758+ template <
2759+ typename KeyInputIteratorT,
2760+ typename ValueInputIteratorT,
2761+ typename KeyOutputIteratorT,
2762+ typename ValueOutputIteratorT,
2763+ typename NumSelectedIteratorT,
2764+ typename NumItemsT,
2765+ typename EnvT = ::cuda::std::execution::env<>,
2766+ ::cuda::std::enable_if_t <::cuda::std::is_integral_v<NumItemsT>, int > = 0 ,
2767+ ::cuda::std::enable_if_t <!::cuda::std::indirect_binary_predicate<EnvT, KeyInputIteratorT, KeyInputIteratorT>, int > =
2768+ 0 >
27532769 CUB_RUNTIME_FUNCTION __forceinline__ static cudaError_t UniqueByKey (
27542770 void * d_temp_storage,
27552771 size_t & temp_storage_bytes,
@@ -2759,7 +2775,7 @@ struct DeviceSelect
27592775 ValueOutputIteratorT d_values_out,
27602776 NumSelectedIteratorT d_num_selected_out,
27612777 NumItemsT num_items,
2762- cudaStream_t stream = nullptr )
2778+ const EnvT& env = {} )
27632779 {
27642780 return UniqueByKey (
27652781 d_temp_storage,
@@ -2771,7 +2787,7 @@ struct DeviceSelect
27712787 d_num_selected_out,
27722788 num_items,
27732789 ::cuda::std::equal_to<>{},
2774- stream );
2790+ env );
27752791 }
27762792};
27772793
0 commit comments