Skip to content

Commit af8cce4

Browse files
authored
[Backport branch/3.3.x] [libcu++] Add missing bit_cast in the buffer construction (#8420) (#8425)
* Add missing bit_cast in the buffer construction * Fix MSVC cast
1 parent bf0584a commit af8cce4

2 files changed

Lines changed: 17 additions & 8 deletions

File tree

libcudacxx/include/cuda/__driver/driver_api.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
#if _CCCL_HAS_CTK() && !_CCCL_COMPILER(NVRTC)
2525

26+
# include <cuda/std/__bit/bit_cast.h>
2627
# include <cuda/std/__cstddef/types.h>
2728
# include <cuda/std/__exception/cuda_error.h>
2829
# include <cuda/std/__host_stdlib/stdexcept>
@@ -348,20 +349,23 @@ _CCCL_HOST_API void __memsetAsync(void* __dst, _Tp __value, ::cuda::std::size_t
348349
if constexpr (sizeof(_Tp) == 1)
349350
{
350351
static auto __driver_fn = _CCCLRT_GET_DRIVER_FUNCTION(cuMemsetD8Async);
352+
auto __bits = ::cuda::std::bit_cast<unsigned char>(__value);
351353
::cuda::__driver::__call_driver_fn(
352-
__driver_fn, "Failed to perform a memset", reinterpret_cast<::CUdeviceptr>(__dst), __value, __count, __stream);
354+
__driver_fn, "Failed to perform a memset", reinterpret_cast<::CUdeviceptr>(__dst), __bits, __count, __stream);
353355
}
354356
else if constexpr (sizeof(_Tp) == 2)
355357
{
356358
static auto __driver_fn = _CCCLRT_GET_DRIVER_FUNCTION(cuMemsetD16Async);
359+
auto __bits = ::cuda::std::bit_cast<unsigned short>(__value);
357360
::cuda::__driver::__call_driver_fn(
358-
__driver_fn, "Failed to perform a memset", reinterpret_cast<::CUdeviceptr>(__dst), __value, __count, __stream);
361+
__driver_fn, "Failed to perform a memset", reinterpret_cast<::CUdeviceptr>(__dst), __bits, __count, __stream);
359362
}
360363
else if constexpr (sizeof(_Tp) == 4)
361364
{
362365
static auto __driver_fn = _CCCLRT_GET_DRIVER_FUNCTION(cuMemsetD32Async);
366+
auto __bits = ::cuda::std::bit_cast<unsigned int>(__value);
363367
::cuda::__driver::__call_driver_fn(
364-
__driver_fn, "Failed to perform a memset", reinterpret_cast<::CUdeviceptr>(__dst), __value, __count, __stream);
368+
__driver_fn, "Failed to perform a memset", reinterpret_cast<::CUdeviceptr>(__dst), __bits, __count, __stream);
365369
}
366370
else
367371
{

libcudacxx/test/libcudacxx/cuda/containers/buffer/helper.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ struct equal_to_value
136136
};
137137

138138
template <class Buffer>
139-
bool equal_size_value(const Buffer& buf, const size_t size, const int value)
139+
bool equal_size_value(const Buffer& buf, const size_t size, const typename Buffer::value_type value)
140140
{
141141
if constexpr (Buffer::properties_list::has_property(cuda::mr::host_accessible{}))
142142
{
@@ -221,11 +221,16 @@ struct extract_properties<cuda::buffer<T, Properties...>>
221221
};
222222

223223
#if _CCCL_CTK_AT_LEAST(12, 9)
224-
using test_types = c2h::type_list<cuda::buffer<int, cuda::mr::host_accessible>,
225-
cuda::buffer<unsigned long long, cuda::mr::device_accessible>,
226-
cuda::buffer<int, cuda::mr::host_accessible, cuda::mr::device_accessible>>;
224+
using test_types =
225+
c2h::type_list<cuda::buffer<int, cuda::mr::host_accessible>,
226+
cuda::buffer<unsigned long long, cuda::mr::device_accessible>,
227+
cuda::buffer<short, cuda::mr::device_accessible>,
228+
cuda::buffer<float, cuda::mr::device_accessible>,
229+
cuda::buffer<int, cuda::mr::host_accessible, cuda::mr::device_accessible>>;
227230
#else // ^^^ _CCCL_CTK_AT_LEAST(12, 9) ^^^ / vvv _CCCL_CTK_BELOW(12, 9) vvv
228-
using test_types = c2h::type_list<cuda::buffer<int, cuda::mr::device_accessible>>;
231+
using test_types = c2h::type_list<cuda::buffer<int, cuda::mr::device_accessible>,
232+
cuda::buffer<short, cuda::mr::device_accessible>,
233+
cuda::buffer<float, cuda::mr::device_accessible>>;
229234
#endif // ^^^ _CCCL_CTK_BELOW(12, 9) ^^^
230235

231236
#endif // CUDA_TEST_CONTAINER_VECTOR_HELPER_H

0 commit comments

Comments
 (0)