Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 26 additions & 12 deletions cub/cub/device/device_histogram.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,9 @@ public:
//! **[inferred]** Signed integer type for sequence offsets, list lengths,
//! pointer differences, etc. @offset_size1
//!
//! @tparam EnvT
//! **[inferred]** Environment type (e.g., `cuda::std::execution::env<...>`)
//!
//! @param[in] d_temp_storage
//! @devicestorage
//!
Expand All @@ -944,11 +947,15 @@ public:
//! @param[in] num_samples
//! The number of data samples per row in the region of interest
//!
//! @param[in] stream
//! @param[in] env
//! @rst
//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`.
//! **[optional]** Execution environment. Default is ``cuda::std::execution::env{}``.
//! @endrst
template <typename SampleIteratorT, typename CounterT, typename LevelT, typename OffsetT>
template <typename SampleIteratorT,
typename CounterT,
typename LevelT,
typename OffsetT,
typename EnvT = ::cuda::std::execution::env<>>
CUB_RUNTIME_FUNCTION static cudaError_t HistogramRange(
void* d_temp_storage,
size_t& temp_storage_bytes,
Expand All @@ -957,7 +964,7 @@ public:
int num_levels,
const LevelT* d_levels,
OffsetT num_samples,
cudaStream_t stream = nullptr)
const EnvT& env = {})
{
/// The sample value type of the input iterator
using SampleT = cub::detail::it_value_t<SampleIteratorT>;
Expand All @@ -971,7 +978,7 @@ public:
num_samples,
(OffsetT) 1,
(size_t) (sizeof(SampleT) * num_samples),
stream);
env);
}

//! @rst
Expand Down Expand Up @@ -1051,6 +1058,9 @@ public:
//! **[inferred]** Signed integer type for sequence offsets, list lengths,
//! pointer differences, etc. @offset_size1
//!
//! @tparam EnvT
//! **[inferred]** Environment type (e.g., `cuda::std::execution::env<...>`)
//!
//! @param[in] d_temp_storage
//! @devicestorage
//!
Expand Down Expand Up @@ -1083,11 +1093,15 @@ public:
//! The number of bytes between starts of consecutive rows in the region
//! of interest
//!
//! @param[in] stream
//! @param[in] env
//! @rst
//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`.
//! **[optional]** Execution environment. Default is ``cuda::std::execution::env{}``.
//! @endrst
template <typename SampleIteratorT, typename CounterT, typename LevelT, typename OffsetT>
template <typename SampleIteratorT,
typename CounterT,
typename LevelT,
typename OffsetT,
typename EnvT = ::cuda::std::execution::env<>>
CUB_RUNTIME_FUNCTION static cudaError_t HistogramRange(
void* d_temp_storage,
size_t& temp_storage_bytes,
Expand All @@ -1098,7 +1112,7 @@ public:
OffsetT num_row_samples,
OffsetT num_rows,
size_t row_stride_bytes,
cudaStream_t stream = nullptr)
const EnvT& env = {})
{
return MultiHistogramRange<1, 1>(
d_temp_storage,
Expand All @@ -1110,7 +1124,7 @@ public:
num_row_samples,
num_rows,
row_stride_bytes,
stream);
env);
}

//! @rst
Expand Down Expand Up @@ -2162,7 +2176,7 @@ public:
int num_levels,
const LevelT* d_levels,
OffsetT num_samples,
EnvT env = {})
const EnvT& env = {})
{
using SampleT = cub::detail::it_value_t<SampleIteratorT>;
return MultiHistogramRange<1, 1>(
Expand Down Expand Up @@ -2268,7 +2282,7 @@ public:
OffsetT num_row_samples,
OffsetT num_rows,
size_t row_stride_bytes,
EnvT env = {})
const EnvT& env = {})
{
return MultiHistogramRange<1, 1>(
d_samples,
Expand Down
102 changes: 102 additions & 0 deletions cub/test/catch2_test_device_histogram_env.cu
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,108 @@ TEST_CASE("DeviceHistogram::HistogramRange works with default environment", "[hi
REQUIRE(d_histogram == expected);
}

// Review feedback on this test: it must be guarded by `#if TEST_LAUNCH == 0`.
// NOTE(review): TEST_LAUNCH == 0 is presumably the plain host-launch variant of the
// test matrix -- confirm against the launch helper used by this file.
#if TEST_LAUNCH == 0
// Exercises the two-phase (size query, then run) HistogramRange overload that takes
// caller-provided temporary storage plus a trailing environment argument, once for
// every kind of environment listed in the SECTIONs below.
TEST_CASE("DeviceHistogram::HistogramRange works with user provided memory and environment", "[histogram][device]")
{
  auto d_samples  = c2h::device_vector<float>{2.2f, 6.1f, 7.5f, 2.9f, 3.5f, 0.3f, 2.9f, 2.1f};
  int num_samples = static_cast<int>(d_samples.size());
  auto d_levels   = c2h::device_vector<float>{0.0f, 2.0f, 4.0f, 6.0f, 8.0f};
  int num_levels  = static_cast<int>(d_levels.size());
  // One bin per pair of adjacent level boundaries.
  auto d_histogram = c2h::device_vector<int>(num_levels - 1, 0);

  c2h::device_vector<int> expected{1, 5, 0, 2};

  // Reference size query without an explicit environment; the environment-taking
  // calls below must report the same temporary storage requirement.
  size_t expected_bytes_allocated{};
  auto error = cub::DeviceHistogram::HistogramRange(
    nullptr,
    expected_bytes_allocated,
    thrust::raw_pointer_cast(d_samples.data()),
    thrust::raw_pointer_cast(d_histogram.data()),
    num_levels,
    thrust::raw_pointer_cast(d_levels.data()),
    num_samples);
  REQUIRE(error == cudaSuccess);
  REQUIRE(cudaSuccess == cudaPeekAtLastError());
  REQUIRE(cudaSuccess == cudaDeviceSynchronize());

  auto d_temp        = c2h::device_vector<uint8_t>(expected_bytes_allocated, thrust::no_init);
  void* temp_storage = thrust::raw_pointer_cast(d_temp.data());

  // Runs both phases with the given environment and checks sizes and results.
  // `env` is taken by const reference so that non-copyable arguments such as
  // `cuda::stream` can be forwarded unchanged.
  auto test_histogram_range = [&](const auto& env) {
    // Phase 1: size query (null temp storage) with the environment under test.
    size_t num_bytes = 0;
    error            = cub::DeviceHistogram::HistogramRange(
      nullptr,
      num_bytes,
      thrust::raw_pointer_cast(d_samples.data()),
      thrust::raw_pointer_cast(d_histogram.data()),
      num_levels,
      thrust::raw_pointer_cast(d_levels.data()),
      num_samples,
      env);
    REQUIRE(error == cudaSuccess);
    REQUIRE(cudaSuccess == cudaPeekAtLastError());
    REQUIRE(cudaSuccess == cudaDeviceSynchronize());
    REQUIRE(expected_bytes_allocated == num_bytes);

    // Phase 2: actual histogram computation into the pre-allocated storage.
    error = cub::DeviceHistogram::HistogramRange(
      temp_storage,
      num_bytes,
      thrust::raw_pointer_cast(d_samples.data()),
      thrust::raw_pointer_cast(d_histogram.data()),
      num_levels,
      thrust::raw_pointer_cast(d_levels.data()),
      num_samples,
      env);
    REQUIRE(error == cudaSuccess);
    REQUIRE(cudaSuccess == cudaPeekAtLastError());
    // Synchronize the whole device so the result is complete before it is compared,
    // regardless of which stream the environment selected.
    REQUIRE(cudaSuccess == cudaDeviceSynchronize());

    // Verify result
    REQUIRE(d_histogram == expected);
  };

  // Zero-initialized so the value is well defined even if the query below fails.
  int current_device = 0;
  error              = cudaGetDevice(&current_device);
  REQUIRE(error == cudaSuccess);

  SECTION("DeviceHistogram::HistogramRange works with cudaStream_t")
  {
    cuda::stream stream{cuda::devices[current_device]};
    test_histogram_range(stream.get());
  }

  SECTION("DeviceHistogram::HistogramRange works with cuda::stream")
  {
    cuda::stream stream{cuda::devices[current_device]};
    test_histogram_range(stream);
  }

  SECTION("DeviceHistogram::HistogramRange works with cuda::stream_ref")
  {
    cuda::stream stream{cuda::devices[current_device]};
    cuda::stream_ref stream_ref{stream};
    test_histogram_range(stream_ref);
  }

  SECTION("DeviceHistogram::HistogramRange works with cuda::std::execution::env")
  {
    cuda::std::execution::env env{};
    test_histogram_range(env);
  }

  SECTION("DeviceHistogram::HistogramRange works with cuda::execution::gpu")
  {
    const auto policy = cuda::execution::gpu;
    test_histogram_range(policy);
  }

  SECTION("DeviceHistogram::HistogramRange works with cuda::execution::gpu with stream")
  {
    cuda::stream stream{cuda::devices[current_device]};
    const auto policy = cuda::execution::gpu.with(cuda::get_stream, stream);
    test_histogram_range(policy);
  }
}
#endif // TEST_LAUNCH == 0

TEST_CASE("DeviceHistogram::MultiHistogramEven works with default environment", "[histogram][device]")
{
[[maybe_unused]] constexpr int NUM_CHANNELS = 4;
Expand Down
Loading