Skip to content

Commit e67552a

Browse files
committed
[CUB] Refactor DeviceHistogram::HistogramRange to always take an environment
We want to be able to pass tunings to the APIs that take user-provided memory.
1 parent ada45f2 commit e67552a

2 files changed

Lines changed: 128 additions & 12 deletions

File tree

cub/cub/device/device_histogram.cuh

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -919,6 +919,9 @@ public:
919919
//! **[inferred]** Signed integer type for sequence offsets, list lengths,
920920
//! pointer differences, etc. @offset_size1
921921
//!
922+
//! @tparam EnvT
923+
//! **[inferred]** Environment type (e.g., `cuda::std::execution::env<...>`)
924+
//!
922925
//! @param[in] d_temp_storage
923926
//! @devicestorage
924927
//!
@@ -944,11 +947,15 @@ public:
944947
//! @param[in] num_samples
945948
//! The number of data samples per row in the region of interest
946949
//!
947-
//! @param[in] stream
950+
//! @param[in] env
948951
//! @rst
949-
//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`.
952+
//! **[optional]** Execution environment. Default is ``cuda::std::execution::env{}``.
950953
//! @endrst
951-
template <typename SampleIteratorT, typename CounterT, typename LevelT, typename OffsetT>
954+
template <typename SampleIteratorT,
955+
typename CounterT,
956+
typename LevelT,
957+
typename OffsetT,
958+
typename EnvT = ::cuda::std::execution::env<>>
952959
CUB_RUNTIME_FUNCTION static cudaError_t HistogramRange(
953960
void* d_temp_storage,
954961
size_t& temp_storage_bytes,
@@ -957,7 +964,7 @@ public:
957964
int num_levels,
958965
const LevelT* d_levels,
959966
OffsetT num_samples,
960-
cudaStream_t stream = nullptr)
967+
const EnvT& env = {})
961968
{
962969
/// The sample value type of the input iterator
963970
using SampleT = cub::detail::it_value_t<SampleIteratorT>;
@@ -971,7 +978,7 @@ public:
971978
num_samples,
972979
(OffsetT) 1,
973980
(size_t) (sizeof(SampleT) * num_samples),
974-
stream);
981+
env);
975982
}
976983

977984
//! @rst
@@ -1051,6 +1058,9 @@ public:
10511058
//! **[inferred]** Signed integer type for sequence offsets, list lengths,
10521059
//! pointer differences, etc. @offset_size1
10531060
//!
1061+
//! @tparam EnvT
1062+
//! **[inferred]** Environment type (e.g., `cuda::std::execution::env<...>`)
1063+
//!
10541064
//! @param[in] d_temp_storage
10551065
//! @devicestorage
10561066
//!
@@ -1083,11 +1093,15 @@ public:
10831093
//! The number of bytes between starts of consecutive rows in the region
10841094
//! of interest
10851095
//!
1086-
//! @param[in] stream
1096+
//! @param[in] env
10871097
//! @rst
1088-
//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`.
1098+
//! **[optional]** Execution environment. Default is ``cuda::std::execution::env{}``.
10891099
//! @endrst
1090-
template <typename SampleIteratorT, typename CounterT, typename LevelT, typename OffsetT>
1100+
template <typename SampleIteratorT,
1101+
typename CounterT,
1102+
typename LevelT,
1103+
typename OffsetT,
1104+
typename EnvT = ::cuda::std::execution::env<>>
10911105
CUB_RUNTIME_FUNCTION static cudaError_t HistogramRange(
10921106
void* d_temp_storage,
10931107
size_t& temp_storage_bytes,
@@ -1098,7 +1112,7 @@ public:
10981112
OffsetT num_row_samples,
10991113
OffsetT num_rows,
11001114
size_t row_stride_bytes,
1101-
cudaStream_t stream = nullptr)
1115+
const EnvT& env = {})
11021116
{
11031117
return MultiHistogramRange<1, 1>(
11041118
d_temp_storage,
@@ -1110,7 +1124,7 @@ public:
11101124
num_row_samples,
11111125
num_rows,
11121126
row_stride_bytes,
1113-
stream);
1127+
env);
11141128
}
11151129

11161130
//! @rst
@@ -2162,7 +2176,7 @@ public:
21622176
int num_levels,
21632177
const LevelT* d_levels,
21642178
OffsetT num_samples,
2165-
EnvT env = {})
2179+
const EnvT& env = {})
21662180
{
21672181
using SampleT = cub::detail::it_value_t<SampleIteratorT>;
21682182
return MultiHistogramRange<1, 1>(
@@ -2268,7 +2282,7 @@ public:
22682282
OffsetT num_row_samples,
22692283
OffsetT num_rows,
22702284
size_t row_stride_bytes,
2271-
EnvT env = {})
2285+
const EnvT& env = {})
22722286
{
22732287
return MultiHistogramRange<1, 1>(
22742288
d_samples,

cub/test/catch2_test_device_histogram_env.cu

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,108 @@ TEST_CASE("DeviceHistogram::HistogramRange works with default environment", "[hi
190190
REQUIRE(d_histogram == expected);
191191
}
192192

193+
// Exercises the cub::DeviceHistogram::HistogramRange overload that takes caller-provided
// temporary storage together with a trailing environment argument. The size query and the
// actual histogram computation are repeated for several kinds of environment argument.
TEST_CASE("DeviceHistogram::HistogramRange works with user provided memory and environment", "[histogram][device]")
194+
{
195+
// Eight input samples, binned against the five boundaries in d_levels (four bins).
auto d_samples = c2h::device_vector<float>{2.2f, 6.1f, 7.5f, 2.9f, 3.5f, 0.3f, 2.9f, 2.1f};
196+
int num_samples = static_cast<int>(d_samples.size());
197+
auto d_levels = c2h::device_vector<float>{0.0f, 2.0f, 4.0f, 6.0f, 8.0f};
198+
int num_levels = static_cast<int>(d_levels.size());
199+
// One counter per bin (num_levels - 1), zero-filled.
auto d_histogram = c2h::device_vector<int>(num_levels - 1, 0);
200+
201+
// Expected counts per bin: [0,2) -> 1, [2,4) -> 5, [4,6) -> 0, [6,8) -> 2.
c2h::device_vector<int> expected{1, 5, 0, 2};
202+
203+
// Reference size query: nullptr temp storage and no environment argument, so the
// defaulted env parameter is used. The byte count is written to expected_bytes_allocated.
size_t expected_bytes_allocated{};
204+
auto error = cub::DeviceHistogram::HistogramRange(
205+
nullptr,
206+
expected_bytes_allocated,
207+
thrust::raw_pointer_cast(d_samples.data()),
208+
thrust::raw_pointer_cast(d_histogram.data()),
209+
num_levels,
210+
thrust::raw_pointer_cast(d_levels.data()),
211+
num_samples);
212+
REQUIRE(error == cudaSuccess);
213+
REQUIRE(cudaSuccess == cudaPeekAtLastError());
214+
REQUIRE(cudaSuccess == cudaDeviceSynchronize());
215+
216+
// Temporary storage buffer sized by the query above; its raw pointer is reused by the helper below.
auto d_temp = c2h::device_vector<uint8_t>(expected_bytes_allocated, thrust::no_init);
217+
void* temp_storage = thrust::raw_pointer_cast(d_temp.data());
218+
219+
// Helper shared by all sections: for the given env it (1) repeats the size query and checks
// that the reported byte count matches the reference query, (2) runs HistogramRange with the
// pre-allocated temp storage, and (3) compares the resulting histogram with `expected`.
// Captures by reference, so `error` and the device vectors above are the ones being used.
auto test_histogram_range = [&](const auto& env) {
220+
size_t num_bytes = 0;
221+
// Size query with an explicit environment (nullptr temp storage).
error = cub::DeviceHistogram::HistogramRange(
222+
nullptr,
223+
num_bytes,
224+
thrust::raw_pointer_cast(d_samples.data()),
225+
thrust::raw_pointer_cast(d_histogram.data()),
226+
num_levels,
227+
thrust::raw_pointer_cast(d_levels.data()),
228+
num_samples,
229+
env);
230+
REQUIRE(error == cudaSuccess);
231+
REQUIRE(cudaSuccess == cudaPeekAtLastError());
232+
REQUIRE(cudaSuccess == cudaDeviceSynchronize());
233+
// The environment must not change the reported temp storage size.
REQUIRE(expected_bytes_allocated == num_bytes);
234+
235+
// Actual run using the caller-provided temp storage and the same environment.
error = cub::DeviceHistogram::HistogramRange(
236+
temp_storage,
237+
num_bytes,
238+
thrust::raw_pointer_cast(d_samples.data()),
239+
thrust::raw_pointer_cast(d_histogram.data()),
240+
num_levels,
241+
thrust::raw_pointer_cast(d_levels.data()),
242+
num_samples,
243+
env);
244+
REQUIRE(error == cudaSuccess);
245+
REQUIRE(cudaSuccess == cudaPeekAtLastError());
246+
// Synchronize the device before the histogram is compared below.
REQUIRE(cudaSuccess == cudaDeviceSynchronize());
247+
248+
// Verify result
249+
REQUIRE(d_histogram == expected);
250+
};
251+
252+
// The current device ordinal is used to index cuda::devices when the sections create streams.
int current_device;
253+
error = cudaGetDevice(&current_device);
254+
REQUIRE(error == cudaSuccess);
255+
256+
// Env argument: the result of stream.get() (per the section title, a cudaStream_t).
SECTION("DeviceHistogram::HistogramRange works with cudaStream_t")
257+
{
258+
cuda::stream stream{cuda::devices[current_device]};
259+
test_histogram_range(stream.get());
260+
}
261+
262+
// Env argument: the cuda::stream object itself.
SECTION("DeviceHistogram::HistogramRange works with cuda::stream")
263+
{
264+
cuda::stream stream{cuda::devices[current_device]};
265+
test_histogram_range(stream);
266+
}
267+
268+
// Env argument: a cuda::stream_ref constructed from a cuda::stream.
SECTION("DeviceHistogram::HistogramRange works with cuda::stream_ref")
269+
{
270+
cuda::stream stream{cuda::devices[current_device]};
271+
cuda::stream_ref stream_ref{stream};
272+
test_histogram_range(stream_ref);
273+
}
274+
275+
// Env argument: an explicitly constructed, empty cuda::std::execution::env.
SECTION("DeviceHistogram::HistogramRange works with cuda::std::execution::env")
276+
{
277+
cuda::std::execution::env env{};
278+
test_histogram_range(env);
279+
}
280+
281+
// Env argument: the cuda::execution::gpu policy object.
SECTION("DeviceHistogram::HistogramRange works with cuda::execution::gpu")
282+
{
283+
const auto policy = cuda::execution::gpu;
284+
test_histogram_range(policy);
285+
}
286+
287+
// Env argument: cuda::execution::gpu combined with a stream via .with(cuda::get_stream, stream).
SECTION("DeviceHistogram::HistogramRange works with cuda::execution::gpu with stream")
288+
{
289+
cuda::stream stream{cuda::devices[current_device]};
290+
const auto policy = cuda::execution::gpu.with(cuda::get_stream, stream);
291+
test_histogram_range(policy);
292+
}
293+
}
294+
193295
TEST_CASE("DeviceHistogram::MultiHistogramEven works with default environment", "[histogram][device]")
194296
{
195297
[[maybe_unused]] constexpr int NUM_CHANNELS = 4;

0 commit comments

Comments
 (0)