23 #if defined(EIGEN_USE_SYCL) && \
24 !defined(EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H)
25 #define EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H
27 #include <CL/sycl.hpp>
28 #ifdef EIGEN_EXCEPTIONS
34 #include <unordered_map>
37 namespace TensorSycl {
40 using sycl_acc_target = cl::sycl::access::target;
41 using sycl_acc_mode = cl::sycl::access::mode;
46 using buffer_data_type_t =
uint8_t;
47 const sycl_acc_target default_acc_target = sycl_acc_target::global_buffer;
48 const sycl_acc_mode default_acc_mode = sycl_acc_mode::read_write;
65 struct virtual_pointer_t {
68 base_ptr_t m_contents;
73 operator void *()
const {
return reinterpret_cast<void *
>(m_contents); }
78 operator base_ptr_t()
const {
return m_contents; }
84 virtual_pointer_t
operator+(
size_t off) {
return m_contents + off; }
87 bool operator<(virtual_pointer_t rhs)
const {
88 return (
static_cast<base_ptr_t
>(m_contents) <
89 static_cast<base_ptr_t
>(rhs.m_contents));
92 bool operator>(virtual_pointer_t rhs)
const {
93 return (
static_cast<base_ptr_t
>(m_contents) >
94 static_cast<base_ptr_t
>(rhs.m_contents));
100 bool operator==(virtual_pointer_t rhs)
const {
101 return (
static_cast<base_ptr_t
>(m_contents) ==
102 static_cast<base_ptr_t
>(rhs.m_contents));
108 bool operator!=(virtual_pointer_t rhs)
const {
118 virtual_pointer_t(
const void *ptr)
119 : m_contents(reinterpret_cast<base_ptr_t>(ptr)){};
125 virtual_pointer_t(base_ptr_t u) : m_contents(u){};
// Virtual pointer constructed from nullptr. A virtual pointer counts as null
// when its conversion to void* compares equal to nullptr (see is_nullptr).
130 const virtual_pointer_t null_virtual_ptr =
nullptr;
136 static inline bool is_nullptr(virtual_pointer_t ptr) {
137 return (
static_cast<void *
>(ptr) ==
nullptr);
// Type-erased SYCL buffer stored in every map node. get_buffer static_casts a
// stored buffer_t back to a typed cl::sycl::buffer<T, 1>, so buffer_t is
// assumed to be a base of the typed buffers — NOTE(review): confirm against the
// SYCL implementation in use.
142 using buffer_t = cl::sycl::buffer_mem;
154 pMapNode_t(buffer_t b,
size_t size,
bool f)
155 : m_buffer{b}, m_size{
size}, m_free{f} {
156 m_buffer.set_final_data(
nullptr);
159 bool operator<=(
const pMapNode_t &rhs) {
return (m_size <= rhs.m_size); }
164 using pointerMap_t = std::map<virtual_pointer_t, pMapNode_t>;
171 typename pointerMap_t::iterator get_insertion_point(
size_t requiredSize) {
172 typename pointerMap_t::iterator retVal;
174 if (!m_freeList.empty()) {
176 for (
auto freeElem : m_freeList) {
177 if (freeElem->second.m_size >= requiredSize) {
181 m_freeList.erase(freeElem);
187 retVal = std::prev(m_pointerMap.end());
202 typename pointerMap_t::iterator get_node(
const virtual_pointer_t ptr) {
203 if (this->count() == 0) {
204 m_pointerMap.clear();
205 EIGEN_THROW_X(std::out_of_range(
"There are no pointers allocated\n"));
208 if (is_nullptr(ptr)) {
209 m_pointerMap.clear();
210 EIGEN_THROW_X(std::out_of_range(
"Cannot access null pointer\n"));
214 auto node = m_pointerMap.lower_bound(ptr);
217 if (node ==
std::end(m_pointerMap)) {
219 }
else if (node->first != ptr) {
220 if (node == std::begin(m_pointerMap)) {
221 m_pointerMap.clear();
223 std::out_of_range(
"The pointer is not registered in the map\n"));
235 template <
typename buffer_data_type = buffer_data_type_t>
236 cl::sycl::buffer<buffer_data_type, 1> get_buffer(
237 const virtual_pointer_t ptr) {
238 using sycl_buffer_t = cl::sycl::buffer<buffer_data_type, 1>;
244 auto node = get_node(ptr);
246 eigen_assert(ptr <
static_cast<virtual_pointer_t
>(node->second.m_size +
248 return *(
static_cast<sycl_buffer_t *
>(&node->second.m_buffer));
257 template <sycl_acc_mode access_mode = default_acc_mode,
258 sycl_acc_target access_target = default_acc_target,
259 typename buffer_data_type = buffer_data_type_t>
260 cl::sycl::accessor<buffer_data_type, 1, access_mode, access_target>
261 get_access(
const virtual_pointer_t ptr) {
262 auto buf = get_buffer<buffer_data_type>(ptr);
263 return buf.template get_access<access_mode, access_target>();
274 template <sycl_acc_mode access_mode = default_acc_mode,
275 sycl_acc_target access_target = default_acc_target,
276 typename buffer_data_type = buffer_data_type_t>
277 cl::sycl::accessor<buffer_data_type, 1, access_mode, access_target>
278 get_access(
const virtual_pointer_t ptr, cl::sycl::handler &cgh) {
279 auto buf = get_buffer<buffer_data_type>(ptr);
280 return buf.template get_access<access_mode, access_target>(cgh);
286 inline std::ptrdiff_t get_offset(
const virtual_pointer_t ptr) {
289 auto node = get_node(ptr);
290 auto start = node->first;
293 return (ptr - start);
300 template <
typename buffer_data_type>
301 inline size_t get_element_offset(
const virtual_pointer_t ptr) {
302 return get_offset(ptr) /
sizeof(buffer_data_type);
308 PointerMapper(base_ptr_t baseAddress = 4096)
309 : m_pointerMap{}, m_freeList{}, m_baseAddress{baseAddress} {
310 if (m_baseAddress == 0) {
311 EIGEN_THROW_X(std::invalid_argument(
"Base address cannot be zero\n"));
318 PointerMapper(
const PointerMapper &) =
delete;
323 inline void clear() {
325 m_pointerMap.clear();
331 inline virtual_pointer_t add_pointer(
const buffer_t &b) {
332 return add_pointer_impl(b);
338 inline virtual_pointer_t add_pointer(buffer_t &&b) {
339 return add_pointer_impl(b);
348 void fuse_forward(
typename pointerMap_t::iterator &node) {
349 while (node != std::prev(m_pointerMap.end())) {
352 auto fwd_node = std::next(node);
353 if (!fwd_node->second.m_free) {
356 auto fwd_size = fwd_node->second.m_size;
357 m_freeList.erase(fwd_node);
358 m_pointerMap.erase(fwd_node);
360 node->second.m_size += fwd_size;
370 void fuse_backward(
typename pointerMap_t::iterator &node) {
371 while (node != m_pointerMap.begin()) {
374 auto prev_node = std::prev(node);
375 if (!prev_node->second.m_free) {
378 prev_node->second.m_size += node->second.m_size;
381 m_freeList.erase(node);
382 m_pointerMap.erase(node);
393 template <
bool ReUse = true>
394 void remove_pointer(
const virtual_pointer_t ptr) {
395 if (is_nullptr(ptr)) {
398 auto node = this->get_node(ptr);
400 node->second.m_free =
true;
401 m_freeList.emplace(node);
410 if (node == std::prev(m_pointerMap.end())) {
411 m_freeList.erase(node);
412 m_pointerMap.erase(node);
420 size_t count()
const {
return (m_pointerMap.size() - m_freeList.size()); }
427 template <
class BufferT>
428 virtual_pointer_t add_pointer_impl(BufferT b) {
429 virtual_pointer_t retVal =
nullptr;
430 size_t bufSize = b.get_count();
431 pMapNode_t p{b, bufSize,
false};
433 if (m_pointerMap.empty()) {
434 virtual_pointer_t initialVal{m_baseAddress};
435 m_pointerMap.emplace(initialVal, p);
439 auto lastElemIter = get_insertion_point(bufSize);
441 if (lastElemIter->second.m_free) {
442 lastElemIter->second.m_buffer = b;
443 lastElemIter->second.m_free =
false;
447 if (lastElemIter->second.m_size > bufSize) {
449 auto remainingSize = lastElemIter->second.m_size - bufSize;
450 pMapNode_t p2{b, remainingSize,
true};
453 lastElemIter->second.m_size = bufSize;
456 auto newFreePtr = lastElemIter->first + bufSize;
457 auto freeNode = m_pointerMap.emplace(newFreePtr, p2).first;
458 m_freeList.emplace(freeNode);
461 retVal = lastElemIter->first;
463 size_t lastSize = lastElemIter->second.m_size;
464 retVal = lastElemIter->first + lastSize;
465 m_pointerMap.emplace(retVal, p);
476 typename pointerMap_t::iterator b)
const {
477 return ((
a->first < b->first) && (
a->second <= b->second)) ||
478 ((
a->first < b->first) && (b->second <=
a->second));
// Every region handed out (or freed but not yet erased), keyed by its virtual
// start address. Declared before m_freeList; keep this order, since it fixes
// member-initialisation order.
484 pointerMap_t m_pointerMap;
// Iterators into m_pointerMap for nodes marked free, available for reuse by
// get_insertion_point. NOTE(review): SortBySize's expression reduces to
// `a->first < b->first` (the two size comparisons together are always true),
// so this set is effectively ordered by address, not size — confirm intent.
488 std::set<typename pointerMap_t::iterator, SortBySize> m_freeList;
500 inline void PointerMapper::remove_pointer<false>(
const virtual_pointer_t ptr) {
501 if (is_nullptr(ptr)) {
504 m_pointerMap.erase(this->get_node(ptr));
514 inline void *SYCLmalloc(
size_t size, PointerMapper &pMap) {
519 using buffer_t = cl::sycl::buffer<buffer_data_type_t, 1>;
520 auto thePointer = pMap.add_pointer(buffer_t(cl::sycl::range<1>{
size}));
522 return static_cast<void *
>(thePointer);
532 template <
bool ReUse = true,
typename Po
interMapper>
533 inline void SYCLfree(
void *ptr, PointerMapper &pMap) {
534 pMap.template remove_pointer<ReUse>(ptr);
540 template <
typename Po
interMapper>
541 inline void SYCLfreeAll(PointerMapper &pMap) {
545 template <cl::sycl::access::mode AcMd,
typename T>
547 static const auto global_access = cl::sycl::access::target::global_buffer;
548 static const auto is_place_holder = cl::sycl::access::placeholder::true_t;
550 typedef scalar_t &ref_t;
554 typedef cl::sycl::accessor<scalar_t, 1, AcMd, global_access, is_place_holder>
557 typedef RangeAccess<AcMd, T> self_t;
561 : access_(access), offset_(offset), virtual_ptr_(virtual_ptr) {}
563 RangeAccess(cl::sycl::buffer<scalar_t, 1> buff =
564 cl::sycl::buffer<scalar_t, 1>(cl::sycl::range<1>(1)))
565 : access_{accessor{buff}}, offset_(0), virtual_ptr_(-1) {}
568 RangeAccess(std::nullptr_t) : RangeAccess() {}
571 return (access_.get_pointer().get() + offset_);
573 template <
typename Index>
578 template <
typename Index>
580 return self_t(access_, offset_ + offset, virtual_ptr_);
582 template <
typename Index>
584 return self_t(access_, offset_ - offset, virtual_ptr_);
586 template <
typename Index>
594 const RangeAccess &lhs, std::nullptr_t) {
595 return ((lhs.virtual_ptr_ == -1));
598 const RangeAccess &lhs, std::nullptr_t i) {
604 std::nullptr_t,
const RangeAccess &rhs) {
605 return ((rhs.virtual_ptr_ == -1));
608 std::nullptr_t i,
const RangeAccess &rhs) {
620 self_t temp_iterator(*
this);
622 return temp_iterator;
626 return (access_.get_count() - offset_);
638 return *get_pointer();
642 return *get_pointer();
648 return *(get_pointer() + x);
652 return *(get_pointer() + x);
656 return reinterpret_cast<scalar_t *
>(virtual_ptr_ +
657 (offset_ *
sizeof(scalar_t)));
661 return (virtual_ptr_ != -1);
665 return RangeAccess<AcMd, const T>(access_, offset_, virtual_ptr_);
669 operator RangeAccess<AcMd, const T>()
const {
670 return RangeAccess<AcMd, const T>(access_, offset_, virtual_ptr_);
674 cl::sycl::handler &cgh)
const {
675 cgh.require(access_);
684 template <cl::sycl::access::mode AcMd,
typename T>
685 struct RangeAccess<AcMd,
const T> : RangeAccess<AcMd, T> {
686 typedef RangeAccess<AcMd, T> Base;
bool operator!=(const json_pointer< RefStringTypeLhs > &lhs, const json_pointer< RefStringTypeRhs > &rhs) noexcept
Definition: json.hpp:15536
internal::enable_if< internal::valid_indexed_view_overload< RowIndices, ColIndices >::value &&internal::traits< typename EIGEN_INDEXED_VIEW_METHOD_TYPE< RowIndices, ColIndices >::type >::ReturnAsIndexedView, typename EIGEN_INDEXED_VIEW_METHOD_TYPE< RowIndices, ColIndices >::type >::type operator()(const RowIndices &rowIndices, const ColIndices &colIndices) EIGEN_INDEXED_VIEW_METHOD_CONST
Definition: IndexedViewMethods.h:73
#define EIGEN_UNUSED_VARIABLE(var)
Definition: Macros.h:1076
#define EIGEN_DEVICE_FUNC
Definition: Macros.h:976
#define eigen_assert(x)
Definition: Macros.h:1037
#define EIGEN_STRONG_INLINE
Definition: Macros.h:917
#define EIGEN_THROW_X(X)
Definition: Macros.h:1403
v1d & operator+=(v1d &a, cv1d &b)
Definition: Tools.cpp:29
bool operator==(const cp< d > &a, const cp< d > &b)
Contact equivalence is based solely on the index of objects in contact i and j.
Definition: ContactList.h:291
v1d & operator-=(v1d &a, cv1d &b)
Definition: Tools.cpp:26
bool operator<(const cp< d > &a, const cp< d > &b)
The contact list is order in increasing order of index i, and for two identical i in increasing order...
Definition: ContactList.h:289
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>(const bfloat16 &a, const bfloat16 &b)
Definition: BFloat16.h:230
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=(const bfloat16 &a, const bfloat16 &b)
Definition: BFloat16.h:227
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16 &a)
Definition: BFloat16.h:200
EIGEN_CONSTEXPR Index size(const T &x)
Definition: Meta.h:479
static EIGEN_DEPRECATED const end_t end
Definition: IndexedViewHelper.h:181
Namespace containing all symbols from the Eigen library.
Definition: LDLT.h:16
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:74
EIGEN_DEVICE_FUNC const Product< MatrixDerived, PermutationDerived, AliasFreeProduct > operator*(const MatrixBase< MatrixDerived > &matrix, const PermutationBase< PermutationDerived > &permutation)
Definition: PermutationMatrix.h:515
EIGEN_STRONG_INLINE const CwiseBinaryOp< internal::scalar_difference_op< typename DenseDerived::Scalar, typename SparseDerived::Scalar >, const DenseDerived, const SparseDerived > operator-(const MatrixBase< DenseDerived > &a, const SparseMatrixBase< SparseDerived > &b)
Definition: SparseCwiseBinaryOp.h:708
EIGEN_STRONG_INLINE const CwiseBinaryOp< internal::scalar_sum_op< typename DenseDerived::Scalar, typename SparseDerived::Scalar >, const DenseDerived, const SparseDerived > operator+(const MatrixBase< DenseDerived > &a, const SparseMatrixBase< SparseDerived > &b)
Definition: SparseCwiseBinaryOp.h:694
typename T::pointer pointer_t
Definition: json.hpp:3640
Definition: document.h:416
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1181
unsigned char uint8_t
Definition: stdint.h:124
_W64 signed int intptr_t
Definition: stdint.h:164
#define const
Definition: zconf.h:233