Skip to content

Commit 45c5b58

Browse files
committed
[CUB] Refactor DevicePartition::Flagged to always take an environment
We want to be able to pass tunings to the APIs that take user provided memory Make sure we can pass any environment or stream type to them
1 parent 7051612 commit 45c5b58

4 files changed

Lines changed: 137 additions & 28 deletions

File tree

cub/cub/device/device_partition.cuh

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,9 @@ struct DevicePartition
180180
//! @tparam NumItemsT
181181
//! **[inferred]** Type of num_items
182182
//!
183+
//! @tparam EnvT
184+
//! **[inferred]** Environment type (e.g., `cuda::std::execution::env<...>`)
185+
//!
183186
//! @param[in] d_temp_storage
184187
//! @devicestorage
185188
//!
@@ -202,15 +205,14 @@ struct DevicePartition
202205
//! @param[in] num_items
203206
//! Total number of items to select from
204207
//!
205-
//! @param[in] stream
206-
//! @rst
207-
//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`.
208-
//! @endrst
208+
//! @param[in] env
209+
//! **[optional]** Execution environment. Default is ``cuda::std::execution::env{}``.
209210
template <typename InputIteratorT,
210211
typename FlagIterator,
211212
typename OutputIteratorT,
212213
typename NumSelectedIteratorT,
213-
typename NumItemsT>
214+
typename NumItemsT,
215+
typename EnvT = ::cuda::std::execution::env<>>
214216
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t Flagged(
215217
void* d_temp_storage,
216218
size_t& temp_storage_bytes,
@@ -219,31 +221,35 @@ struct DevicePartition
219221
OutputIteratorT d_out,
220222
NumSelectedIteratorT d_num_selected_out,
221223
NumItemsT num_items,
222-
cudaStream_t stream = nullptr)
224+
const EnvT& env = {})
223225
{
224226
_CCCL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DevicePartition::Flagged");
225-
using ChooseOffsetT = detail::choose_signed_offset<NumItemsT>;
226-
using OffsetT = typename ChooseOffsetT::type; // Signed integer type for global offsets
227-
using SelectOp = NullType; // Selection op (not used)
228-
using EqualityOp = NullType; // Equality operator (not used)
227+
using choose_offset_t = detail::choose_signed_offset<NumItemsT>;
228+
using offset_t = typename choose_offset_t::type;
229+
using default_policy_selector = detail::select::
230+
policy_selector_from_types<InputIteratorT, FlagIterator, OutputIteratorT, offset_t, SelectImpl::Partition>;
229231

230232
// Check if the number of items exceeds the range covered by the selected signed offset type
231-
if (const cudaError_t error = ChooseOffsetT::is_exceeding_offset_type(num_items))
233+
if (const auto error = choose_offset_t::is_exceeding_offset_type(num_items))
232234
{
233235
return error;
234236
}
235237

236-
return detail::select::dispatch<SelectImpl::Partition>(
237-
d_temp_storage,
238-
temp_storage_bytes,
239-
d_in,
240-
d_flags,
241-
d_out,
242-
d_num_selected_out,
243-
SelectOp{},
244-
EqualityOp{},
245-
static_cast<OffsetT>(num_items),
246-
stream);
238+
return detail::dispatch_with_env_and_tuning<default_policy_selector>(
239+
d_temp_storage, temp_storage_bytes, env, [&](auto policy_selector, void* storage, size_t& bytes, auto stream) {
240+
return detail::select::dispatch<SelectImpl::Partition>(
241+
storage,
242+
bytes,
243+
d_in,
244+
d_flags,
245+
d_out,
246+
d_num_selected_out,
247+
NullType{},
248+
NullType{},
249+
static_cast<offset_t>(num_items),
250+
stream,
251+
policy_selector);
252+
});
247253
}
248254

249255
//! @rst
@@ -329,7 +335,7 @@ struct DevicePartition
329335
OutputIteratorT d_out,
330336
NumSelectedIteratorT d_num_selected_out,
331337
NumItemsT num_items,
332-
EnvT env = {})
338+
const EnvT& env = {})
333339
{
334340
_CCCL_NVTX_RANGE_SCOPE("cub::DevicePartition::Flagged");
335341

cub/test/catch2_test_device_partition_flagged.cu

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
#include <thrust/reverse.h>
1111

1212
#include <cuda/cmath>
13+
#include <cuda/devices>
1314
#include <cuda/iterator>
15+
#include <cuda/std/execution>
1416
#include <cuda/std/iterator>
1517

1618
#include <algorithm>
@@ -202,6 +204,107 @@ C2H_TEST("DevicePartition::Flagged is stable", "[device][partition_flagged]")
202204
REQUIRE(reference == out);
203205
}
204206

207+
#if TEST_LAUNCH == 0
208+
C2H_TEST("DevicePartition::Flagged works with user provided memory and environment",
209+
"[device][partition_flagged]",
210+
all_types)
211+
{
212+
using type = typename c2h::get<0, TestType>;
213+
214+
const int num_items = GENERATE_COPY(take(2, random(1, 1000000)));
215+
c2h::device_vector<type> in(num_items, thrust::default_init);
216+
c2h::device_vector<type> out(num_items, thrust::default_init);
217+
c2h::gen(C2H_SEED(2), in);
218+
219+
c2h::device_vector<int> flags(num_items, thrust::no_init);
220+
c2h::gen(C2H_SEED(1), flags, 0, 1);
221+
222+
const int num_selected = static_cast<int>(thrust::count(c2h::device_policy, flags.begin(), flags.end(), 1));
223+
const c2h::host_vector<type> reference = get_reference(in, flags);
224+
225+
// Needs to be device accessible
226+
c2h::device_vector<int> num_selected_out(1, 0);
227+
int* d_num_selected_out = thrust::raw_pointer_cast(num_selected_out.data());
228+
229+
size_t expected_allocation_size = 0;
230+
auto error = cub::DevicePartition::Flagged(
231+
static_cast<void*>(nullptr),
232+
expected_allocation_size,
233+
in.begin(),
234+
flags.begin(),
235+
out.begin(),
236+
d_num_selected_out,
237+
num_items);
238+
REQUIRE(error == cudaSuccess);
239+
REQUIRE(cudaSuccess == cudaPeekAtLastError());
240+
REQUIRE(cudaSuccess == cudaDeviceSynchronize());
241+
242+
auto d_temp = c2h::device_vector<uint8_t>(expected_allocation_size, thrust::no_init);
243+
void* temp_storage = thrust::raw_pointer_cast(d_temp.data());
244+
245+
auto test_partition_flagged = [&](const auto& env) {
246+
size_t num_bytes = 0;
247+
error = cub::DevicePartition::Flagged(
248+
static_cast<void*>(nullptr), num_bytes, in.begin(), flags.begin(), out.begin(), d_num_selected_out, num_items, env);
249+
REQUIRE(error == cudaSuccess);
250+
REQUIRE(cudaSuccess == cudaPeekAtLastError());
251+
REQUIRE(cudaSuccess == cudaDeviceSynchronize());
252+
REQUIRE(expected_allocation_size == num_bytes);
253+
254+
error = cub::DevicePartition::Flagged(
255+
temp_storage, num_bytes, in.begin(), flags.begin(), out.begin(), d_num_selected_out, num_items, env);
256+
REQUIRE(error == cudaSuccess);
257+
REQUIRE(cudaSuccess == cudaPeekAtLastError());
258+
REQUIRE(cudaSuccess == cudaDeviceSynchronize());
259+
260+
REQUIRE(num_selected == num_selected_out[0]);
261+
REQUIRE(reference == out);
262+
};
263+
264+
int current_device;
265+
error = cudaGetDevice(&current_device);
266+
REQUIRE(error == cudaSuccess);
267+
268+
SECTION("DevicePartition::Flagged works with cudaStream_t")
269+
{
270+
cuda::stream stream{cuda::devices[current_device]};
271+
test_partition_flagged(stream.get());
272+
}
273+
274+
SECTION("DevicePartition::Flagged works with cuda::stream")
275+
{
276+
cuda::stream stream{cuda::devices[current_device]};
277+
test_partition_flagged(stream);
278+
}
279+
280+
SECTION("DevicePartition::Flagged works with cuda::stream_ref")
281+
{
282+
cuda::stream stream{cuda::devices[current_device]};
283+
cuda::stream_ref stream_ref{stream};
284+
test_partition_flagged(stream_ref);
285+
}
286+
287+
SECTION("DevicePartition::Flagged works with cuda::std::execution::env")
288+
{
289+
cuda::std::execution::env env{};
290+
test_partition_flagged(env);
291+
}
292+
293+
SECTION("DevicePartition::Flagged works with cuda::execution::gpu")
294+
{
295+
const auto policy = cuda::execution::gpu;
296+
test_partition_flagged(policy);
297+
}
298+
299+
SECTION("DevicePartition::Flagged works with cuda::execution::gpu with stream")
300+
{
301+
cuda::stream stream{cuda::devices[current_device]};
302+
const auto policy = cuda::execution::gpu.with(cuda::get_stream, stream);
303+
test_partition_flagged(policy);
304+
}
305+
}
306+
#endif // TEST_LAUNCH == 0
307+
205308
C2H_TEST("DevicePartition::Flagged works with iterators", "[device][partition_flagged]", all_types)
206309
{
207310
using type = typename c2h::get<0, TestType>;

libcudacxx/include/cuda/std/__pstl/cuda/rotate.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ struct __pstl_dispatch<__pstl_algorithm::__rotate, __execution_backend::__cuda>
105105
__output_wrapper,
106106
static_cast<_OffsetType*>(nullptr),
107107
__count,
108-
nullptr);
108+
__policy);
109109

110110
{
111111
// Allocate memory for result
@@ -131,9 +131,9 @@ struct __pstl_dispatch<__pstl_algorithm::__rotate, __execution_backend::__cuda>
131131
__storage.template __get_raw_ptr<1>(),
132132
::cuda::transform_iterator{::cuda::counting_iterator<size_t>{0}, __rotate_fn{__count1}},
133133
::cuda::std::move(__output_wrapper),
134-
__storage.template __get_ptr<0>(),
134+
__storage.template __get_raw_ptr<0>(),
135135
__count,
136-
__stream.get());
136+
__policy);
137137
}
138138

139139
__stream.sync();

libcudacxx/include/cuda/std/__pstl/cuda/rotate_copy.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ struct __pstl_dispatch<__pstl_algorithm::__rotate_copy, __execution_backend::__c
106106
__output_wrapper,
107107
static_cast<_OffsetType*>(nullptr),
108108
__count,
109-
nullptr);
109+
__policy);
110110

111111
{
112112
// Allocate memory for result
@@ -123,7 +123,7 @@ struct __pstl_dispatch<__pstl_algorithm::__rotate_copy, __execution_backend::__c
123123
::cuda::std::move(__output_wrapper),
124124
__storage.template __get_ptr<0>(),
125125
__count,
126-
__stream.get());
126+
__policy);
127127
}
128128

129129
__stream.sync();

0 commit comments

Comments
 (0)