@@ -329,10 +329,7 @@ func transformType(
329
329
let text = name. text
330
330
let isRaw = isRawPointerType ( text: text)
331
331
if isRaw && !isSizedBy {
332
- throw DiagnosticError ( " raw pointers only supported for SizedBy " , node: name)
333
- }
334
- if !isRaw && isSizedBy {
335
- throw DiagnosticError ( " SizedBy only supported for raw pointers " , node: name)
332
+ throw DiagnosticError ( " void pointers not supported for countedBy " , node: name)
336
333
}
337
334
338
335
guard let kind: Mutability = getPointerMutability ( text: text) else {
@@ -375,6 +372,33 @@ func isMutablePointerType(_ type: TypeSyntax) -> Bool {
375
372
}
376
373
}
377
374
375
+ func getPointeeType( _ type: TypeSyntax ) -> TypeSyntax ? {
376
+ if let optType = type. as ( OptionalTypeSyntax . self) {
377
+ return getPointeeType ( optType. wrappedType)
378
+ }
379
+ if let impOptType = type. as ( ImplicitlyUnwrappedOptionalTypeSyntax . self) {
380
+ return getPointeeType ( impOptType. wrappedType)
381
+ }
382
+ if let attrType = type. as ( AttributedTypeSyntax . self) {
383
+ return getPointeeType ( attrType. baseType)
384
+ }
385
+
386
+ guard let idType = type. as ( IdentifierTypeSyntax . self) else {
387
+ return nil
388
+ }
389
+ let text = idType. name. text
390
+ if text != " UnsafePointer " && text != " UnsafeMutablePointer " {
391
+ return nil
392
+ }
393
+ guard let x = idType. genericArgumentClause else {
394
+ return nil
395
+ }
396
+ guard let y = x. arguments. first else {
397
+ return nil
398
+ }
399
+ return y. argument. as ( TypeSyntax . self)
400
+ }
401
+
378
402
protocol BoundsCheckedThunkBuilder {
379
403
func buildFunctionCall( _ pointerArgs: [ Int : ExprSyntax ] ) throws -> ExprSyntax
380
404
// buildBasicBoundsChecks creates a variable with the same name as the parameter it replaced,
@@ -648,6 +672,7 @@ extension PointerBoundsThunkBuilder {
648
672
return try transformType ( oldType, generateSpan, isSizedBy, isParameter)
649
673
}
650
674
}
675
+
651
676
var countLabel : String {
652
677
return isSizedBy && generateSpan ? " byteCount " : " count "
653
678
}
@@ -826,7 +851,7 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
826
851
var args = argOverrides
827
852
let argExpr = ExprSyntax ( " \( unwrappedName) .baseAddress " )
828
853
assert ( args [ index] == nil )
829
- args [ index] = try castPointerToOpaquePointer ( unwrapIfNonnullable ( argExpr) )
854
+ args [ index] = try castPointerToTargetType ( unwrapIfNonnullable ( argExpr) )
830
855
let call = try base. buildFunctionCall ( args)
831
856
let ptrRef = unwrapIfNullable ( " \( name) " )
832
857
@@ -871,11 +896,16 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
871
896
return type
872
897
}
873
898
874
- func castPointerToOpaquePointer ( _ baseAddress: ExprSyntax ) throws -> ExprSyntax {
899
+ func castPointerToTargetType ( _ baseAddress: ExprSyntax ) throws -> ExprSyntax {
875
900
let type = peelOptionalType ( getParam ( signature, index) . type)
876
901
if type. canRepresentBasicType ( type: OpaquePointer . self) {
877
902
return ExprSyntax ( " OpaquePointer( \( baseAddress) ) " )
878
903
}
904
+ if isSizedBy {
905
+ if let pointeeType = getPointeeType ( type) {
906
+ return " \( baseAddress) .assumingMemoryBound(to: \( pointeeType) .self) "
907
+ }
908
+ }
879
909
return baseAddress
880
910
}
881
911
@@ -907,7 +937,7 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
907
937
return unwrappedCall
908
938
}
909
939
910
- args [ index] = try castPointerToOpaquePointer ( getPointerArg ( ) )
940
+ args [ index] = try castPointerToTargetType ( getPointerArg ( ) )
911
941
return try base. buildFunctionCall ( args)
912
942
}
913
943
}
0 commit comments