Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 46 additions & 21 deletions gpt_oss/metal/source/metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -96,18 +96,19 @@ enum gptoss_status gptoss_metal_library_create_default(
enum gptoss_status status = gptoss_status_success;
id<MTLDevice> device_obj = (id<MTLDevice>) device->object;
id<MTLLibrary> library_obj = nil;
NSError* error_obj = nil;
NSString* error_string_obj = nil;
NSAutoreleasePool* autorelease_pool = nil;
dispatch_data_t library_blob = NULL;

unsigned long library_size = 0;
uint8_t* library_data = getsectiondata(&__dso_handle, "__METAL", "__shaders", &library_size);
if (library_data != NULL) {
library_blob = dispatch_data_create(library_data, library_size, NULL, DISPATCH_DATA_DESTRUCTOR_DEFAULT);

autorelease_pool = [[NSAutoreleasePool alloc] init];
NSError* error_obj = nil;
library_obj = [device_obj newLibraryWithData:library_blob error:&error_obj];
if (library_obj == nil) {
error_string_obj = [error_obj localizedDescription];
GPTOSS_LOG_ERROR("failed to create Metal library: %s", [error_string_obj UTF8String]);
GPTOSS_LOG_ERROR("failed to create Metal library: %s", [[error_obj localizedDescription] UTF8String]);
status = gptoss_status_unsupported_system;
goto cleanup;
}
Expand All @@ -129,11 +130,8 @@ enum gptoss_status gptoss_metal_library_create_default(
if (library_blob != NULL) {
dispatch_release(library_blob);
}
if (error_string_obj != nil) {
[error_string_obj release];
}
if (error_obj != nil) {
[error_obj release];
if (autorelease_pool != nil) {
[autorelease_pool drain];
}
return status;
}
Expand All @@ -154,26 +152,50 @@ enum gptoss_status gptoss_metal_function_create(
const char* name,
struct gptoss_metal_function* function_out)
{
NSString* name_obj = nil;
NSError* error_obj = nil;
NSString* error_string_obj = nil;
__block NSString* error_string_obj = nil;
id<MTLFunction> function_obj = nil;
MTLComputePipelineDescriptor* pipeline_descriptor_obj = nil;
__block id<MTLComputePipelineState> pipeline_state_obj = nil;
dispatch_semaphore_t pipeline_build_semaphore = NULL;
enum gptoss_status status = gptoss_status_success;

NSAutoreleasePool* autorelease_pool = [[NSAutoreleasePool alloc] init];
id<MTLLibrary> library_obj = (id<MTLLibrary>) library->object;
name_obj = [NSString stringWithUTF8String:name];
NSString* name_obj = [NSString stringWithUTF8String:name];
function_obj = [library_obj newFunctionWithName:name_obj];
if (function_obj == nil) {
GPTOSS_LOG_ERROR("failed to create Metal function %s", name);
status = gptoss_status_unsupported_system;
goto cleanup;
}
id<MTLDevice> device_obj = [library_obj device];
id<MTLComputePipelineState> pipeline_state_obj = [device_obj newComputePipelineStateWithFunction:function_obj error:&error_obj];
pipeline_descriptor_obj = [[MTLComputePipelineDescriptor alloc] init];
[pipeline_descriptor_obj setComputeFunction:function_obj];
[pipeline_descriptor_obj setThreadGroupSizeIsMultipleOfThreadExecutionWidth:YES];

pipeline_build_semaphore = dispatch_semaphore_create(/*value=*/0);
[device_obj newComputePipelineStateWithDescriptor:pipeline_descriptor_obj
options:MTLPipelineOptionNone
completionHandler:^(id<MTLComputePipelineState> _Nullable new_state,
MTLComputePipelineReflection* _Nullable reflection,
NSError* _Nullable error_obj) {
if (new_state != nil) {
pipeline_state_obj = [new_state retain];
}
if (error_obj != nil) {
error_string_obj = [[error_obj localizedDescription] copy];
}
dispatch_semaphore_signal(pipeline_build_semaphore);
}];
dispatch_semaphore_wait(pipeline_build_semaphore, DISPATCH_TIME_FOREVER);

if (pipeline_state_obj == nil) {
error_string_obj = [error_obj localizedDescription];
const char* error_string = "unknown error";
if (error_string_obj != nil) {
error_string = [error_string_obj UTF8String];
}
GPTOSS_LOG_ERROR("failed to create Metal compute pipeline state for function %s: %s",
name, [error_string_obj UTF8String]);
name, error_string);
status = gptoss_status_unsupported_system;
goto cleanup;
}
Expand All @@ -189,17 +211,20 @@ enum gptoss_status gptoss_metal_function_create(
pipeline_state_obj = nil;

cleanup:
if (name_obj != nil) {
[name_obj release];
}
if (function_obj != nil) {
[function_obj release];
}
if (pipeline_descriptor_obj != nil) {
[pipeline_descriptor_obj release];
}
if (error_string_obj != nil) {
[error_string_obj release];
}
if (error_obj != nil) {
[error_obj release];
if (pipeline_build_semaphore != NULL) {
dispatch_release(pipeline_build_semaphore);
}
if (autorelease_pool != nil) {
[autorelease_pool drain];
}
return status;
}
Expand Down