|
7 | 7 | //===-----------------------------------------------------------------===//
|
8 | 8 | #pragma once
|
9 | 9 |
|
| 10 | +#include <unordered_map> |
| 11 | + |
10 | 12 | #include "common.hpp"
|
11 | 13 | #include "device.hpp"
|
12 | 14 | #include "platform.hpp"
|
@@ -93,9 +95,61 @@ struct ur_context_handle_t_ {
|
93 | 95 |
|
94 | 96 | uint32_t getReferenceCount() const noexcept { return RefCount; }
|
95 | 97 |
|
| 98 | + /// We need to keep track of USM mappings in AMD HIP, as certain extra |
| 99 | + /// synchronization *is* actually required for correctness. |
| 100 | + /// During kernel enqueue we must dispatch a prefetch for each kernel argument |
| 101 | + /// that points to a USM mapping to ensure the mapping is correctly |
| 102 | + /// populated on the device (https://github.com/intel/llvm/issues/7252). Thus, |
| 103 | + /// we keep track of mappings in the context, and then check against them just |
| 104 | + /// before the kernel is launched. The stream against which the kernel is |
| 105 | + /// launched is not known until enqueue time, but the USM mappings can happen |
| 106 | + /// at any time. Thus, they are tracked on the context used for the urUSM* |
| 107 | + /// mapping. |
| 108 | + /// |
  /// The three utility functions are simple wrappers around a mapping from a
  /// pointer to a size.
| 111 | + void addUSMMapping(void *Ptr, size_t Size) { |
| 112 | + std::lock_guard<std::mutex> Guard(Mutex); |
| 113 | + assert(USMMappings.find(Ptr) == USMMappings.end() && |
| 114 | + "mapping already exists"); |
| 115 | + USMMappings[Ptr] = Size; |
| 116 | + } |
| 117 | + |
| 118 | + void removeUSMMapping(const void *Ptr) { |
| 119 | + std::lock_guard<std::mutex> guard(Mutex); |
| 120 | + auto It = USMMappings.find(Ptr); |
| 121 | + if (It != USMMappings.end()) |
| 122 | + USMMappings.erase(It); |
| 123 | + } |
| 124 | + |
| 125 | + std::pair<const void *, size_t> getUSMMapping(const void *Ptr) { |
| 126 | + std::lock_guard<std::mutex> Guard(Mutex); |
| 127 | + auto It = USMMappings.find(Ptr); |
| 128 | + // The simple case is the fast case... |
| 129 | + if (It != USMMappings.end()) |
| 130 | + return *It; |
| 131 | + |
| 132 | + // ... but in the failure case we have to fall back to a full scan to search |
| 133 | + // for "offset" pointers in case the user passes in the middle of an |
| 134 | + // allocation. We have to do some not-so-ordained-by-the-standard ordered |
| 135 | + // comparisons of pointers here, but it'll work on all platforms we support. |
| 136 | + uintptr_t PtrVal = (uintptr_t)Ptr; |
| 137 | + for (std::pair<const void *, size_t> Pair : USMMappings) { |
| 138 | + uintptr_t BaseAddr = (uintptr_t)Pair.first; |
| 139 | + uintptr_t EndAddr = BaseAddr + Pair.second; |
| 140 | + if (PtrVal > BaseAddr && PtrVal < EndAddr) { |
| 141 | + // If we've found something now, offset *must* be nonzero |
| 142 | + assert(Pair.second); |
| 143 | + return Pair; |
| 144 | + } |
| 145 | + } |
| 146 | + return {nullptr, 0}; |
| 147 | + } |
| 148 | + |
96 | 149 | private:
|
97 | 150 | std::mutex Mutex;
|
98 | 151 | std::vector<deleter_data> ExtendedDeleters;
|
| 152 | + std::unordered_map<const void *, size_t> USMMappings; |
99 | 153 | };
|
100 | 154 |
|
101 | 155 | namespace {
|
|
0 commit comments