|
6 | 6 |
|
7 | 7 | #pragma once |
8 | 8 |
|
9 | | -#include <iostream> |
10 | | -#include <stdexcept> |
11 | | - |
12 | | -static void throw_exception(const std::string& str) { |
13 | | - std::cerr << str << std::endl; |
14 | | - throw std::runtime_error(str); |
15 | | -} |
16 | | - |
17 | | -inline void dispatch_block( |
18 | | - [[maybe_unused]] id<MTLCommandQueue> queue, |
19 | | - void (^block)()) { |
20 | | - __block std::optional<std::exception_ptr> block_exception; |
21 | | - try { |
22 | | - block(); |
23 | | - } catch (...) { |
24 | | - block_exception = std::current_exception(); |
25 | | - } |
26 | | - if (block_exception) { |
27 | | - std::rethrow_exception(*block_exception); |
28 | | - } |
29 | | -} |
30 | | - |
31 | | -inline id<MTLDevice> getMetalDevice() { |
32 | | - @autoreleasepool { |
33 | | - NSArray* devices = [MTLCopyAllDevices() autorelease]; |
34 | | - if (devices.count == 0) { |
35 | | - throw_exception("Metal is not supported"); |
36 | | - } |
37 | | - return devices[0]; |
38 | | - } |
39 | | -} |
40 | | - |
41 | | -static id<MTLDevice> MTL_DEVICE = getMetalDevice(); |
42 | | - |
43 | | -static id<MTLLibrary> compileLibraryFromSource( |
44 | | - id<MTLDevice> device, |
45 | | - const std::string& source) { |
46 | | - NSError* error = nil; |
47 | | - MTLCompileOptions* options = [MTLCompileOptions new]; |
48 | | - [options setLanguageVersion:MTLLanguageVersion3_1]; |
49 | | - NSString* kernel_source = [NSString stringWithUTF8String:source.c_str()]; |
50 | | - id<MTLLibrary> library = [device newLibraryWithSource:kernel_source |
51 | | - options:options |
52 | | - error:&error]; |
53 | | - if (library == nil) { |
54 | | - throw_exception( |
55 | | - "Failed to compile: " + std::string(error.description.UTF8String)); |
56 | | - } |
57 | | - return library; |
58 | | -} |
59 | | - |
60 | | -class MetalShaderLibrary { |
61 | | - public: |
62 | | - MetalShaderLibrary(const std::string& src) : shaderSource(src) { |
63 | | - lib = compileLibraryFromSource(device, shaderSource); |
64 | | - } |
65 | | - MetalShaderLibrary(const MetalShaderLibrary&) = delete; |
66 | | - MetalShaderLibrary(MetalShaderLibrary&&) = delete; |
67 | | - |
68 | | - id<MTLComputePipelineState> getPipelineStateForFunc( |
69 | | - const std::string& fname) { |
70 | | - return get_compute_pipeline_state(load_func(fname)); |
71 | | - } |
72 | | - |
73 | | - private: |
74 | | - std::string shaderSource; |
75 | | - id<MTLDevice> device = MTL_DEVICE; |
76 | | - id<MTLLibrary> lib = nil; |
77 | | - |
78 | | - id<MTLFunction> load_func(const std::string& func_name) const { |
79 | | - id<MTLFunction> func = [lib |
80 | | - newFunctionWithName:[NSString stringWithUTF8String:func_name.c_str()]]; |
81 | | - if (func == nil) { |
82 | | - throw_exception("Can't get function:" + func_name); |
83 | | - } |
84 | | - return func; |
85 | | - } |
86 | | - |
87 | | - id<MTLComputePipelineState> get_compute_pipeline_state( |
88 | | - id<MTLFunction> func) const { |
89 | | - NSError* error = nil; |
90 | | - auto cpl = [device newComputePipelineStateWithFunction:func error:&error]; |
91 | | - if (cpl == nil) { |
92 | | - throw_exception( |
93 | | - "Failed to construct pipeline state: " + |
94 | | - std::string(error.description.UTF8String)); |
95 | | - } |
96 | | - return cpl; |
97 | | - } |
98 | | -}; |
| 9 | +id<MTLDevice> getMetalDevice(); |
99 | 10 |
|
100 | 11 | class MPSStream { |
101 | 12 | public: |
102 | 13 | MPSStream() { |
103 | | - _commandQueue = [MTL_DEVICE newCommandQueue]; |
| 14 | + _commandQueue = [getMetalDevice() newCommandQueue]; |
104 | 15 | } |
105 | 16 |
|
106 | 17 | ~MPSStream() { |
@@ -136,14 +47,6 @@ class MPSStream { |
136 | 47 | id<MTLComputeCommandEncoder> _commandEncoder = nil; |
137 | 48 | }; |
138 | 49 |
|
139 | | -inline void finalize_block(MPSStream* mpsStream) { |
140 | | - id<MTLCommandEncoder> encoder = mpsStream->commandEncoder(); |
141 | | - id<MTLCommandBuffer> cmdBuffer = mpsStream->commandBuffer(); |
142 | | - [encoder endEncoding]; |
143 | | - [cmdBuffer commit]; |
144 | | - [cmdBuffer waitUntilCompleted]; |
145 | | -} |
146 | | - |
147 | 50 | inline MPSStream* getCurrentMPSStream() { |
148 | 51 | return new MPSStream(); |
149 | 52 | } |
0 commit comments