Skip to content

Commit 5425b25

Browse files
committed
[CUB] Refactor DevicePartition::If to always take an environment
We want to be able to pass tunings to the APIs that take user-provided memory. Make sure we can pass any environment or stream type to them.
1 parent c94525f commit 5425b25

5 files changed

Lines changed: 171 additions & 49 deletions

File tree

cub/cub/device/device_partition.cuh

Lines changed: 59 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -454,6 +454,9 @@ struct DevicePartition
454454
//! @tparam NumItemsT
455455
//! **[inferred]** Type of num_items
456456
//!
457+
//! @tparam EnvT
458+
//! **[inferred]** Environment type (e.g., `cuda::std::execution::env<...>`)
459+
//!
457460
//! @param[in] d_temp_storage
458461
//! @devicestorage
459462
//!
@@ -475,15 +478,14 @@ struct DevicePartition
475478
//! @param[in] select_op
476479
//! Unary selection operator
477480
//!
478-
//! @param[in] stream
479-
//! @rst
480-
//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`.
481-
//! @endrst
481+
//! @param[in] env
482+
//! **[optional]** Execution environment. Default is ``cuda::std::execution::env{}``.
482483
template <typename InputIteratorT,
483484
typename OutputIteratorT,
484485
typename NumSelectedIteratorT,
485486
typename SelectOp,
486-
typename NumItemsT>
487+
typename NumItemsT,
488+
typename EnvT = ::cuda::std::execution::env<>>
487489
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t
488490
If(void* d_temp_storage,
489491
size_t& temp_storage_bytes,
@@ -492,31 +494,36 @@ struct DevicePartition
492494
NumSelectedIteratorT d_num_selected_out,
493495
NumItemsT num_items,
494496
SelectOp select_op,
495-
cudaStream_t stream = nullptr)
497+
const EnvT& env = {})
496498
{
497499
_CCCL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DevicePartition::If");
498-
using ChooseOffsetT = detail::choose_signed_offset<NumItemsT>;
499-
using OffsetT = typename ChooseOffsetT::type; // Signed integer type for global offsets
500-
using FlagIterator = NullType*; // FlagT iterator type (not used)
501-
using EqualityOp = NullType; // Equality operator (not used)
500+
501+
using choose_offset_t = detail::choose_signed_offset<NumItemsT>;
502+
using offset_t = typename choose_offset_t::type;
503+
using default_policy_selector = detail::select::
504+
policy_selector_from_types<InputIteratorT, NullType*, OutputIteratorT, offset_t, SelectImpl::Partition>;
502505

503506
// Check if the number of items exceeds the range covered by the selected signed offset type
504-
if (const cudaError_t error = ChooseOffsetT::is_exceeding_offset_type(num_items))
507+
if (const auto error = choose_offset_t::is_exceeding_offset_type(num_items))
505508
{
506509
return error;
507510
}
508511

509-
return detail::select::dispatch<SelectImpl::Partition>(
510-
d_temp_storage,
511-
temp_storage_bytes,
512-
d_in,
513-
FlagIterator{nullptr},
514-
d_out,
515-
d_num_selected_out,
516-
select_op,
517-
EqualityOp{},
518-
static_cast<OffsetT>(num_items),
519-
stream);
512+
return detail::dispatch_with_env_and_tuning<default_policy_selector>(
513+
d_temp_storage, temp_storage_bytes, env, [&](auto policy_selector, void* storage, size_t& bytes, auto stream) {
514+
return detail::select::dispatch<SelectImpl::Partition>(
515+
storage,
516+
bytes,
517+
d_in,
518+
static_cast<NullType*>(nullptr),
519+
d_out,
520+
d_num_selected_out,
521+
select_op,
522+
NullType{},
523+
static_cast<offset_t>(num_items),
524+
stream,
525+
policy_selector);
526+
});
520527
}
521528

522529
//! @rst
@@ -599,7 +606,7 @@ struct DevicePartition
599606
NumSelectedIteratorT d_num_selected_out,
600607
NumItemsT num_items,
601608
SelectOp select_op,
602-
EnvT env = {})
609+
const EnvT& env = {})
603610
{
604611
_CCCL_NVTX_RANGE_SCOPE("cub::DevicePartition::If");
605612

@@ -780,6 +787,9 @@ struct DevicePartition
780787
//! @tparam NumItemsT
781788
//! **[inferred]** Type of num_items
782789
//!
790+
//! @tparam EnvT
791+
//! **[inferred]** Environment type (e.g., `cuda::std::execution::env<...>`)
792+
//!
783793
//! @param[in] d_temp_storage
784794
//! @devicestorage
785795
//!
@@ -814,18 +824,17 @@ struct DevicePartition
814824
//! @param[in] select_second_part_op
815825
//! Unary selection operator to select `d_second_part_out`
816826
//!
817-
//! @param[in] stream
818-
//! @rst
819-
//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`.
820-
//! @endrst
827+
//! @param[in] env
828+
//! **[optional]** Execution environment. Default is ``cuda::std::execution::env{}``.
821829
template <typename InputIteratorT,
822830
typename FirstOutputIteratorT,
823831
typename SecondOutputIteratorT,
824832
typename UnselectedOutputIteratorT,
825833
typename NumSelectedIteratorT,
826834
typename SelectFirstPartOp,
827835
typename SelectSecondPartOp,
828-
typename NumItemsT>
836+
typename NumItemsT,
837+
typename EnvT = ::cuda::std::execution::env<>>
829838
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t
830839
If(void* d_temp_storage,
831840
size_t& temp_storage_bytes,
@@ -837,7 +846,7 @@ struct DevicePartition
837846
NumItemsT num_items,
838847
SelectFirstPartOp select_first_part_op,
839848
SelectSecondPartOp select_second_part_op,
840-
cudaStream_t stream = nullptr)
849+
const EnvT& env = {})
841850
{
842851
_CCCL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DevicePartition::If");
843852
using choose_offset_t = detail::choose_signed_offset<NumItemsT>;
@@ -849,18 +858,26 @@ struct DevicePartition
849858
}
850859

851860
using offset_t = typename choose_offset_t::type;
852-
return detail::three_way_partition::dispatch(
853-
d_temp_storage,
854-
temp_storage_bytes,
855-
d_in,
856-
d_first_part_out,
857-
d_second_part_out,
858-
d_unselected_out,
859-
d_num_selected_out,
860-
select_first_part_op,
861-
select_second_part_op,
862-
static_cast<offset_t>(num_items),
863-
stream);
861+
using default_policy_selector =
862+
detail::three_way_partition::policy_selector_from_types<detail::it_value_t<InputIteratorT>,
863+
detail::three_way_partition::per_partition_offset_t>;
864+
865+
return detail::dispatch_with_env_and_tuning<default_policy_selector>(
866+
d_temp_storage, temp_storage_bytes, env, [&](auto policy_selector, void* storage, size_t& bytes, auto stream) {
867+
return detail::three_way_partition::dispatch(
868+
storage,
869+
bytes,
870+
d_in,
871+
d_first_part_out,
872+
d_second_part_out,
873+
d_unselected_out,
874+
d_num_selected_out,
875+
select_first_part_op,
876+
select_second_part_op,
877+
static_cast<offset_t>(num_items),
878+
stream,
879+
policy_selector);
880+
});
864881
}
865882

866883
//! @rst
@@ -984,7 +1001,7 @@ struct DevicePartition
9841001
NumItemsT num_items,
9851002
SelectFirstPartOp select_first_part_op,
9861003
SelectSecondPartOp select_second_part_op,
987-
EnvT env = {})
1004+
const EnvT& env = {})
9881005
{
9891006
_CCCL_NVTX_RANGE_SCOPE("cub::DevicePartition::If");
9901007

cub/test/catch2_test_device_partition_if.cu

Lines changed: 105 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,8 +10,10 @@
1010
#include <thrust/reverse.h>
1111

1212
#include <cuda/cmath>
13+
#include <cuda/devices>
1314
#include <cuda/functional>
1415
#include <cuda/iterator>
16+
#include <cuda/std/execution>
1517
#include <cuda/std/iterator>
1618

1719
#include <algorithm>
@@ -170,6 +172,109 @@ C2H_TEST("DevicePartition::If is stable", "[device][partition_if]")
170172
REQUIRE(reference == out);
171173
}
172174

175+
#if TEST_LAUNCH == 0
176+
C2H_TEST("DevicePartition::If works with user provided memory and environment", "[device][partition_if]", all_types)
177+
{
178+
using type = typename c2h::get<0, TestType>;
179+
180+
const int num_items = GENERATE_COPY(take(2, random(1, 1000000)));
181+
c2h::device_vector<type> in(num_items, thrust::default_init);
182+
c2h::device_vector<type> out(num_items, thrust::default_init);
183+
c2h::gen(C2H_SEED(2), in);
184+
185+
// just pick one of the input elements as boundary
186+
less_than_t<type> le{in[num_items / 2]};
187+
188+
// Needs to be device accessible
189+
c2h::device_vector<int> num_selected_out(1, 0);
190+
int* d_first_num_selected_out = thrust::raw_pointer_cast(num_selected_out.data());
191+
192+
// Ensure that we create the same output as std
193+
c2h::host_vector<type> reference = in;
194+
// The main difference between stable_partition and DevicePartition::If is that the false partition is in reverse
195+
// order
196+
const auto boundary = std::stable_partition(reference.begin(), reference.end(), le);
197+
std::reverse(boundary, reference.end());
198+
199+
size_t expected_allocation_size = 0;
200+
auto error = cub::DevicePartition::If(
201+
static_cast<void*>(nullptr),
202+
expected_allocation_size,
203+
in.begin(),
204+
out.begin(),
205+
d_first_num_selected_out,
206+
num_items,
207+
le);
208+
REQUIRE(error == cudaSuccess);
209+
REQUIRE(cudaSuccess == cudaPeekAtLastError());
210+
REQUIRE(cudaSuccess == cudaDeviceSynchronize());
211+
212+
auto d_temp = c2h::device_vector<uint8_t>(expected_allocation_size, thrust::no_init);
213+
void* temp_storage = thrust::raw_pointer_cast(d_temp.data());
214+
215+
auto test_partition_if = [&](const auto& env) {
216+
size_t num_bytes = 0;
217+
error = cub::DevicePartition::If(
218+
static_cast<void*>(nullptr), num_bytes, in.begin(), out.begin(), d_first_num_selected_out, num_items, le, env);
219+
REQUIRE(error == cudaSuccess);
220+
REQUIRE(cudaSuccess == cudaPeekAtLastError());
221+
REQUIRE(cudaSuccess == cudaDeviceSynchronize());
222+
REQUIRE(expected_allocation_size == num_bytes);
223+
224+
error = cub::DevicePartition::If(
225+
temp_storage, num_bytes, in.begin(), out.begin(), d_first_num_selected_out, num_items, le, env);
226+
REQUIRE(error == cudaSuccess);
227+
REQUIRE(cudaSuccess == cudaPeekAtLastError());
228+
REQUIRE(cudaSuccess == cudaDeviceSynchronize());
229+
230+
REQUIRE(num_selected_out[0] == cuda::std::distance(reference.begin(), boundary));
231+
REQUIRE(reference == out);
232+
};
233+
234+
int current_device;
235+
error = cudaGetDevice(&current_device);
236+
REQUIRE(error == cudaSuccess);
237+
238+
SECTION("DevicePartition::If works with cudaStream_t")
239+
{
240+
cuda::stream stream{cuda::devices[current_device]};
241+
test_partition_if(stream.get());
242+
}
243+
244+
SECTION("DevicePartition::If works with cuda::stream")
245+
{
246+
cuda::stream stream{cuda::devices[current_device]};
247+
test_partition_if(stream);
248+
}
249+
250+
SECTION("DevicePartition::If works with cuda::stream_ref")
251+
{
252+
cuda::stream stream{cuda::devices[current_device]};
253+
cuda::stream_ref stream_ref{stream};
254+
test_partition_if(stream_ref);
255+
}
256+
257+
SECTION("DevicePartition::If works with cuda::std::execution::env")
258+
{
259+
cuda::std::execution::env env{};
260+
test_partition_if(env);
261+
}
262+
263+
SECTION("DevicePartition::If works with cuda::execution::gpu")
264+
{
265+
const auto policy = cuda::execution::gpu;
266+
test_partition_if(policy);
267+
}
268+
269+
SECTION("DevicePartition::If works with cuda::execution::gpu with stream")
270+
{
271+
cuda::stream stream{cuda::devices[current_device]};
272+
const auto policy = cuda::execution::gpu.with(cuda::get_stream, stream);
273+
test_partition_if(policy);
274+
}
275+
}
276+
#endif // TEST_LAUNCH == 0
277+
173278
C2H_TEST("DevicePartition::If works with iterators", "[device][partition_if]", all_types)
174279
{
175280
using type = typename c2h::get<0, TestType>;

libcudacxx/include/cuda/std/__pstl/cuda/partition.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -87,7 +87,7 @@ struct __pstl_dispatch<__pstl_algorithm::__partition, __execution_backend::__cud
8787
static_cast<_OffsetType*>(nullptr),
8888
__count,
8989
__pred,
90-
nullptr);
90+
__policy);
9191

9292
{
9393
__temporary_storage<_OffsetType, value_type> __storage{__policy, __num_bytes, 1, __count};
@@ -114,7 +114,7 @@ struct __pstl_dispatch<__pstl_algorithm::__partition, __execution_backend::__cud
114114
__storage.template __get_ptr<0>(),
115115
__count,
116116
::cuda::std::move(__pred),
117-
__stream.get());
117+
__policy);
118118

119119
// Copy the result back from storage
120120
_CCCL_TRY_CUDA_API(

libcudacxx/include/cuda/std/__pstl/cuda/partition_copy.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -93,7 +93,7 @@ struct __pstl_dispatch<__pstl_algorithm::__partition_copy, __execution_backend::
9393
static_cast<_OffsetType*>(nullptr),
9494
__count,
9595
__pred,
96-
nullptr);
96+
__policy);
9797

9898
{
9999
__temporary_storage<_OffsetType> __storage{__policy, __num_bytes, 1};
@@ -109,7 +109,7 @@ struct __pstl_dispatch<__pstl_algorithm::__partition_copy, __execution_backend::
109109
__storage.template __get_ptr<0>(),
110110
__count,
111111
::cuda::std::move(__pred),
112-
__stream.get());
112+
__policy);
113113

114114
// Copy the result back from storage
115115
_CCCL_TRY_CUDA_API(

libcudacxx/include/cuda/std/__pstl/cuda/stable_partition.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -88,7 +88,7 @@ struct __pstl_dispatch<__pstl_algorithm::__stable_partition, __execution_backend
8888
static_cast<_OffsetType*>(nullptr),
8989
__count,
9090
__pred,
91-
__stream.get());
91+
__policy);
9292

9393
{
9494
__temporary_storage<_OffsetType, value_type> __storage{__policy, __num_bytes, 1, __count};
@@ -102,7 +102,7 @@ struct __pstl_dispatch<__pstl_algorithm::__stable_partition, __execution_backend
102102
__count,
103103
::cuda::always_true{},
104104
identity{},
105-
__stream.get());
105+
__policy);
106106

107107
// Run the kernel, the standard requires that the input and output range do not overlap
108108
_CCCL_TRY_CUDA_API(
@@ -115,7 +115,7 @@ struct __pstl_dispatch<__pstl_algorithm::__stable_partition, __execution_backend
115115
__storage.template __get_ptr<0>(),
116116
__count,
117117
::cuda::std::move(__pred),
118-
__stream.get());
118+
__policy);
119119

120120
// Copy the result back from storage
121121
_CCCL_TRY_CUDA_API(

0 commit comments

Comments
 (0)