Skip to content

Commit 95d7716

Browse files
authored
Metal: indicate threadgroup is a multiple of simdgroup (#168)
2% speedup on gpt-oss-20b end-to-end sampling
1 parent: 864020a · commit: 95d7716

File tree

1 file changed: 46 additions and 21 deletions.

gpt_oss/metal/source/metal.m

Lines changed: 46 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -96,18 +96,19 @@ enum gptoss_status gptoss_metal_library_create_default(
9696
enum gptoss_status status = gptoss_status_success;
9797
id<MTLDevice> device_obj = (id<MTLDevice>) device->object;
9898
id<MTLLibrary> library_obj = nil;
99-
NSError* error_obj = nil;
100-
NSString* error_string_obj = nil;
99+
NSAutoreleasePool* autorelease_pool = nil;
101100
dispatch_data_t library_blob = NULL;
102101

103102
unsigned long library_size = 0;
104103
uint8_t* library_data = getsectiondata(&__dso_handle, "__METAL", "__shaders", &library_size);
105104
if (library_data != NULL) {
106105
library_blob = dispatch_data_create(library_data, library_size, NULL, DISPATCH_DATA_DESTRUCTOR_DEFAULT);
106+
107+
autorelease_pool = [[NSAutoreleasePool alloc] init];
108+
NSError* error_obj = nil;
107109
library_obj = [device_obj newLibraryWithData:library_blob error:&error_obj];
108110
if (library_obj == nil) {
109-
error_string_obj = [error_obj localizedDescription];
110-
GPTOSS_LOG_ERROR("failed to create Metal library: %s", [error_string_obj UTF8String]);
111+
GPTOSS_LOG_ERROR("failed to create Metal library: %s", [[error_obj localizedDescription] UTF8String]);
111112
status = gptoss_status_unsupported_system;
112113
goto cleanup;
113114
}
@@ -129,11 +130,8 @@ enum gptoss_status gptoss_metal_library_create_default(
129130
if (library_blob != NULL) {
130131
dispatch_release(library_blob);
131132
}
132-
if (error_string_obj != nil) {
133-
[error_string_obj release];
134-
}
135-
if (error_obj != nil) {
136-
[error_obj release];
133+
if (autorelease_pool != nil) {
134+
[autorelease_pool drain];
137135
}
138136
return status;
139137
}
@@ -154,26 +152,50 @@ enum gptoss_status gptoss_metal_function_create(
154152
const char* name,
155153
struct gptoss_metal_function* function_out)
156154
{
157-
NSString* name_obj = nil;
158-
NSError* error_obj = nil;
159-
NSString* error_string_obj = nil;
155+
__block NSString* error_string_obj = nil;
160156
id<MTLFunction> function_obj = nil;
157+
MTLComputePipelineDescriptor* pipeline_descriptor_obj = nil;
158+
__block id<MTLComputePipelineState> pipeline_state_obj = nil;
159+
dispatch_semaphore_t pipeline_build_semaphore = NULL;
161160
enum gptoss_status status = gptoss_status_success;
162161

162+
NSAutoreleasePool* autorelease_pool = [[NSAutoreleasePool alloc] init];
163163
id<MTLLibrary> library_obj = (id<MTLLibrary>) library->object;
164-
name_obj = [NSString stringWithUTF8String:name];
164+
NSString* name_obj = [NSString stringWithUTF8String:name];
165165
function_obj = [library_obj newFunctionWithName:name_obj];
166166
if (function_obj == nil) {
167167
GPTOSS_LOG_ERROR("failed to create Metal function %s", name);
168168
status = gptoss_status_unsupported_system;
169169
goto cleanup;
170170
}
171171
id<MTLDevice> device_obj = [library_obj device];
172-
id<MTLComputePipelineState> pipeline_state_obj = [device_obj newComputePipelineStateWithFunction:function_obj error:&error_obj];
172+
pipeline_descriptor_obj = [[MTLComputePipelineDescriptor alloc] init];
173+
[pipeline_descriptor_obj setComputeFunction:function_obj];
174+
[pipeline_descriptor_obj setThreadGroupSizeIsMultipleOfThreadExecutionWidth:YES];
175+
176+
pipeline_build_semaphore = dispatch_semaphore_create(/*value=*/0);
177+
[device_obj newComputePipelineStateWithDescriptor:pipeline_descriptor_obj
178+
options:MTLPipelineOptionNone
179+
completionHandler:^(id<MTLComputePipelineState> _Nullable new_state,
180+
MTLComputePipelineReflection* _Nullable reflection,
181+
NSError* _Nullable error_obj) {
182+
if (new_state != nil) {
183+
pipeline_state_obj = [new_state retain];
184+
}
185+
if (error_obj != nil) {
186+
error_string_obj = [[error_obj localizedDescription] copy];
187+
}
188+
dispatch_semaphore_signal(pipeline_build_semaphore);
189+
}];
190+
dispatch_semaphore_wait(pipeline_build_semaphore, DISPATCH_TIME_FOREVER);
191+
173192
if (pipeline_state_obj == nil) {
174-
error_string_obj = [error_obj localizedDescription];
193+
const char* error_string = "unknown error";
194+
if (error_string_obj != nil) {
195+
error_string = [error_string_obj UTF8String];
196+
}
175197
GPTOSS_LOG_ERROR("failed to create Metal compute pipeline state for function %s: %s",
176-
name, [error_string_obj UTF8String]);
198+
name, error_string);
177199
status = gptoss_status_unsupported_system;
178200
goto cleanup;
179201
}
@@ -189,17 +211,20 @@ enum gptoss_status gptoss_metal_function_create(
189211
pipeline_state_obj = nil;
190212

191213
cleanup:
192-
if (name_obj != nil) {
193-
[name_obj release];
194-
}
195214
if (function_obj != nil) {
196215
[function_obj release];
197216
}
217+
if (pipeline_descriptor_obj != nil) {
218+
[pipeline_descriptor_obj release];
219+
}
198220
if (error_string_obj != nil) {
199221
[error_string_obj release];
200222
}
201-
if (error_obj != nil) {
202-
[error_obj release];
223+
if (pipeline_build_semaphore != NULL) {
224+
dispatch_release(pipeline_build_semaphore);
225+
}
226+
if (autorelease_pool != nil) {
227+
[autorelease_pool drain];
203228
}
204229
return status;
205230
}

0 commit comments

Comments
 (0)