@@ -454,6 +454,9 @@ struct DevicePartition
454454 // ! @tparam NumItemsT
455455 // ! **[inferred]** Type of num_items
456456 // !
457+ // ! @tparam EnvT
458+ // ! **[inferred]** Environment type (e.g., `cuda::std::execution::env<...>`)
459+ // !
457460 // ! @param[in] d_temp_storage
458461 // ! @devicestorage
459462 // !
@@ -475,15 +478,14 @@ struct DevicePartition
475478 // ! @param[in] select_op
476479 // ! Unary selection operator
477480 // !
478- // ! @param[in] stream
479- // ! @rst
480- // ! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`.
481- // ! @endrst
481+ // ! @param[in] env
482+ // ! **[optional]** Execution environment. Default is ``cuda::std::execution::env{}``.
482483 template <typename InputIteratorT,
483484 typename OutputIteratorT,
484485 typename NumSelectedIteratorT,
485486 typename SelectOp,
486- typename NumItemsT>
487+ typename NumItemsT,
488+ typename EnvT = ::cuda::std::execution::env<>>
487489 CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t
488490 If (void * d_temp_storage,
489491 size_t & temp_storage_bytes,
@@ -492,31 +494,36 @@ struct DevicePartition
492494 NumSelectedIteratorT d_num_selected_out,
493495 NumItemsT num_items,
494496 SelectOp select_op,
495- cudaStream_t stream = nullptr )
497+ const EnvT& env = {} )
496498 {
497499 _CCCL_NVTX_RANGE_SCOPE_IF (d_temp_storage, " cub::DevicePartition::If" );
498- using ChooseOffsetT = detail::choose_signed_offset<NumItemsT>;
499- using OffsetT = typename ChooseOffsetT::type; // Signed integer type for global offsets
500- using FlagIterator = NullType*; // FlagT iterator type (not used)
501- using EqualityOp = NullType; // Equality operator (not used)
500+
501+ using choose_offset_t = detail::choose_signed_offset<NumItemsT>;
502+ using offset_t = typename choose_offset_t ::type;
503+ using default_policy_selector = detail::select::
504+ policy_selector_from_types<InputIteratorT, NullType*, OutputIteratorT, offset_t , SelectImpl::Partition>;
502505
503506 // Check if the number of items exceeds the range covered by the selected signed offset type
504- if (const cudaError_t error = ChooseOffsetT ::is_exceeding_offset_type (num_items))
507+ if (const auto error = choose_offset_t ::is_exceeding_offset_type (num_items))
505508 {
506509 return error;
507510 }
508511
509- return detail::select::dispatch<SelectImpl::Partition>(
510- d_temp_storage,
511- temp_storage_bytes,
512- d_in,
513- FlagIterator{nullptr },
514- d_out,
515- d_num_selected_out,
516- select_op,
517- EqualityOp{},
518- static_cast <OffsetT>(num_items),
519- stream);
512+ return detail::dispatch_with_env_and_tuning<default_policy_selector>(
513+ d_temp_storage, temp_storage_bytes, env, [&](auto policy_selector, void * storage, size_t & bytes, auto stream) {
514+ return detail::select::dispatch<SelectImpl::Partition>(
515+ storage,
516+ bytes,
517+ d_in,
518+ static_cast <NullType*>(nullptr ),
519+ d_out,
520+ d_num_selected_out,
521+ select_op,
522+ NullType{},
523+ static_cast <offset_t >(num_items),
524+ stream,
525+ policy_selector);
526+ });
520527 }
521528
522529 // ! @rst
@@ -599,7 +606,7 @@ struct DevicePartition
599606 NumSelectedIteratorT d_num_selected_out,
600607 NumItemsT num_items,
601608 SelectOp select_op,
602- EnvT env = {})
609+ const EnvT& env = {})
603610 {
604611 _CCCL_NVTX_RANGE_SCOPE (" cub::DevicePartition::If" );
605612
@@ -780,6 +787,9 @@ struct DevicePartition
780787 // ! @tparam NumItemsT
781788 // ! **[inferred]** Type of num_items
782789 // !
790+ // ! @tparam EnvT
791+ // ! **[inferred]** Environment type (e.g., `cuda::std::execution::env<...>`)
792+ // !
783793 // ! @param[in] d_temp_storage
784794 // ! @devicestorage
785795 // !
@@ -814,18 +824,17 @@ struct DevicePartition
814824 // ! @param[in] select_second_part_op
815825 // ! Unary selection operator to select `d_second_part_out`
816826 // !
817- // ! @param[in] stream
818- // ! @rst
819- // ! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`.
820- // ! @endrst
827+ // ! @param[in] env
828+ // ! **[optional]** Execution environment. Default is ``cuda::std::execution::env{}``.
821829 template <typename InputIteratorT,
822830 typename FirstOutputIteratorT,
823831 typename SecondOutputIteratorT,
824832 typename UnselectedOutputIteratorT,
825833 typename NumSelectedIteratorT,
826834 typename SelectFirstPartOp,
827835 typename SelectSecondPartOp,
828- typename NumItemsT>
836+ typename NumItemsT,
837+ typename EnvT = ::cuda::std::execution::env<>>
829838 CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t
830839 If (void * d_temp_storage,
831840 size_t & temp_storage_bytes,
@@ -837,7 +846,7 @@ struct DevicePartition
837846 NumItemsT num_items,
838847 SelectFirstPartOp select_first_part_op,
839848 SelectSecondPartOp select_second_part_op,
840- cudaStream_t stream = nullptr )
849+ const EnvT& env = {} )
841850 {
842851 _CCCL_NVTX_RANGE_SCOPE_IF (d_temp_storage, " cub::DevicePartition::If" );
843852 using choose_offset_t = detail::choose_signed_offset<NumItemsT>;
@@ -849,18 +858,26 @@ struct DevicePartition
849858 }
850859
851860 using offset_t = typename choose_offset_t ::type;
852- return detail::three_way_partition::dispatch (
853- d_temp_storage,
854- temp_storage_bytes,
855- d_in,
856- d_first_part_out,
857- d_second_part_out,
858- d_unselected_out,
859- d_num_selected_out,
860- select_first_part_op,
861- select_second_part_op,
862- static_cast <offset_t >(num_items),
863- stream);
861+ using default_policy_selector =
862+ detail::three_way_partition::policy_selector_from_types<detail::it_value_t <InputIteratorT>,
863+ detail::three_way_partition::per_partition_offset_t >;
864+
865+ return detail::dispatch_with_env_and_tuning<default_policy_selector>(
866+ d_temp_storage, temp_storage_bytes, env, [&](auto policy_selector, void * storage, size_t & bytes, auto stream) {
867+ return detail::three_way_partition::dispatch (
868+ storage,
869+ bytes,
870+ d_in,
871+ d_first_part_out,
872+ d_second_part_out,
873+ d_unselected_out,
874+ d_num_selected_out,
875+ select_first_part_op,
876+ select_second_part_op,
877+ static_cast <offset_t >(num_items),
878+ stream,
879+ policy_selector);
880+ });
864881 }
865882
866883 // ! @rst
@@ -984,7 +1001,7 @@ struct DevicePartition
9841001 NumItemsT num_items,
9851002 SelectFirstPartOp select_first_part_op,
9861003 SelectSecondPartOp select_second_part_op,
987- EnvT env = {})
1004+ const EnvT& env = {})
9881005 {
9891006 _CCCL_NVTX_RANGE_SCOPE (" cub::DevicePartition::If" );
9901007
0 commit comments