Skip to content

Commit c681268

Browse files
committed
Add validation-tests for AutoDiff closure spec pass (multi-BB case)
1 parent fec443a commit c681268

File tree

2 files changed

+323
-0
lines changed

2 files changed

+323
-0
lines changed
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
/// Multi basic block VJP, pullback accepting branch tracing enum argument.
2+
3+
// REQUIRES: executable_test
4+
5+
// RUN: %empty-directory(%t)
6+
// RUN: %target-build-swift %s -o %t/none.out -Onone
7+
// RUN: %target-build-swift %s -o %t/opt.out -O
8+
// RUN: %target-run %t/none.out
9+
// RUN: %target-run %t/opt.out
10+
11+
// RUN: %target-swift-frontend -emit-sil %s -O -o %t/out.sil
12+
// RUN: cat %t/out.sil | %FileCheck %s --check-prefix=CHECK1
13+
// RUN: cat %t/out.sil | %FileCheck %s --check-prefix=CHECK2
14+
// RUN: cat %t/out.sil | %FileCheck %s --check-prefix=CHECK3
15+
16+
import DifferentiationUnittest
17+
import StdlibUnittest
18+
19+
var AutoDiffClosureSpecMultiBBBTETests = TestSuite("AutoDiffClosureSpecMultiBBBTE")
20+
21+
AutoDiffClosureSpecMultiBBBTETests.testWithLeakChecking("Test1") {
22+
// CHECK1-LABEL: {{^}}// reverse-mode derivative of mul42 #1 (_:)
23+
// CHECK1-NEXT: sil private @$s3outyycfU_5mul42L_yS2fSgFTJrSpSr : $@convention(thin) (Optional<Float>) -> (Float, @owned @callee_guaranteed (Float) -> Optional<Float>.TangentVector) {
24+
// CHECK1: %[[#A12:]] = function_ref @$s3outyycfU_5mul42L_yS2fSgFTJpSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktj1_k5FZSf_K6SfcfU_S2fTf1nnc_n : $@convention(thin) (Float, @owned _AD__$s3outyycfU_5mul42L_yS2fSgF_bb2__Pred__src_0_wrt_0, Float, Float) -> Optional<Float>.TangentVector
25+
// CHECK1: %[[#A13:]] = partial_apply [callee_guaranteed] %[[#A12]](%[[#]], %[[#]], %[[#]]) : $@convention(thin) (Float, @owned _AD__$s3outyycfU_5mul42L_yS2fSgF_bb2__Pred__src_0_wrt_0, Float, Float) -> Optional<Float>.TangentVector
26+
// CHECK1: %[[#A14:]] = tuple (%[[#]], %[[#A13]])
27+
// CHECK1: return %[[#A14]]
28+
// CHECK1: } // end sil function '$s3outyycfU_5mul42L_yS2fSgFTJrSpSr'
29+
30+
// CHECK1-NONE: {{^}}// pullback of mul42
31+
// CHECK1: {{^}}// specialized pullback of mul42
32+
// CHECK1: sil private @$s3outyycfU_5mul42L_yS2fSgFTJpSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktj1_k5FZSf_K6SfcfU_S2fTf1nnc_n : $@convention(thin) (Float, @owned _AD__$s3outyycfU_5mul42L_yS2fSgF_bb2__Pred__src_0_wrt_0, Float, Float) -> Optional<Float>.TangentVector {
33+
34+
@differentiable(reverse)
35+
func mul42(_ a: Float?) -> Float {
36+
let b = 42 * a!
37+
return b
38+
}
39+
40+
expectEqual((-84, 42), valueWithGradient(at: -2, of: mul42))
41+
expectEqual((0, 42), valueWithGradient(at: 0, of: mul42))
42+
expectEqual((42, 42), valueWithGradient(at: 1, of: mul42))
43+
expectEqual((210, 42), valueWithGradient(at: 5, of: mul42))
44+
}
45+
46+
AutoDiffClosureSpecMultiBBBTETests.testWithLeakChecking("Test2") {
47+
// CHECK2-LABEL: {{^}}// reverse-mode derivative of cond_tuple_var #1 (_:)
48+
// CHECK2-NEXT: sil private @$s3outyycfU0_14cond_tuple_varL_yS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
49+
// CHECK2: %[[#E41:]] = function_ref @$s3outyycfU0_14cond_tuple_varL_yS2fFTJpSpSr067$sSf16_DifferentiationE7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktl1_m5FZSf_M6SfcfU_0ef1_g4E12_i16Subtract3lhs3rhsk1_l1_mnl1_mo1_mP2U_ACTf1nnccc_n : $@convention(thin) (Float, @owned _AD__$s3outyycfU0_14cond_tuple_varL_yS2fF_bb3__Pred__src_0_wrt_0) -> Float
50+
// CHECK2: %[[#E42:]] = partial_apply [callee_guaranteed] %[[#E41]](%[[#]]) : $@convention(thin) (Float, @owned _AD__$s3outyycfU0_14cond_tuple_varL_yS2fF_bb3__Pred__src_0_wrt_0) -> Float
51+
// CHECK2: %[[#E46:]] = tuple (%[[#]], %[[#E42]])
52+
// CHECK2: return %[[#E46]]
53+
// CHECK2: } // end sil function '$s3outyycfU0_14cond_tuple_varL_yS2fFTJrSpSr'
54+
55+
// CHECK2-NONE: {{^}}// pullback of cond_tuple_var
56+
// CHECK2: {{^}}// specialized pullback of cond_tuple_var
57+
// CHECK2: sil private @$s3outyycfU0_14cond_tuple_varL_yS2fFTJpSpSr067$sSf16_DifferentiationE7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktl1_m5FZSf_M6SfcfU_0ef1_g4E12_i16Subtract3lhs3rhsk1_l1_mnl1_mo1_mP2U_ACTf1nnccc_n : $@convention(thin) (Float, @owned _AD__$s3outyycfU0_14cond_tuple_varL_yS2fF_bb3__Pred__src_0_wrt_0) -> Float {
58+
59+
func cond_tuple_var(_ x: Float) -> Float {
60+
// Convoluted function returning `x + x`.
61+
var y: (Float, Float) = (x, x)
62+
var z: (Float, Float) = (x + x, x - x)
63+
if x > 0 {
64+
var w = (x, x)
65+
y.0 = w.1
66+
y.1 = w.0
67+
z.0 = z.0 - y.0
68+
z.1 = z.1 + y.0
69+
} else {
70+
z = (1 * x, x)
71+
}
72+
return y.0 + y.1 - z.0 + z.1
73+
}
74+
75+
expectEqual((8, 2), valueWithGradient(at: 4, of: cond_tuple_var))
76+
expectEqual((-20, 2), valueWithGradient(at: -10, of: cond_tuple_var))
77+
expectEqual((-2674, 2), valueWithGradient(at: -1337, of: cond_tuple_var))
78+
}
79+
80+
AutoDiffClosureSpecMultiBBBTETests.testWithLeakChecking("Test3") {
81+
struct Class: Differentiable {
82+
var stored: Float
83+
var optional: Float?
84+
85+
init(stored: Float, optional: Float?) {
86+
self.stored = stored
87+
self.optional = optional
88+
}
89+
90+
// CHECK3-LABEL: {{^}}// reverse-mode derivative of method()
91+
// CHECK3-NEXT: sil private @$s3outyycfU1_5ClassL_V6methodSfyFTJrSpSr : $@convention(method) (Class) -> (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) {
92+
// CHECK3: %[[#C44:]] = function_ref @$s3outyycfU1_5ClassL_V6methodSfyFTJpSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktk1_l5FZSf_L6SfcfU_S2fAES2fTf1nncc_n : $@convention(thin) (Float, @owned _AD__$s3outyycfU1_5ClassL_V6methodSfyF_bb3__Pred__src_0_wrt_0, Float, Float, Float, Float) -> Class.TangentVector
93+
// CHECK3: %[[#C45:]] = partial_apply [callee_guaranteed] %[[#C44]](%[[#]], %[[#]], %[[#]], %[[#]], %[[#]]) : $@convention(thin) (Float, @owned _AD__$s3outyycfU1_5ClassL_V6methodSfyF_bb3__Pred__src_0_wrt_0, Float, Float, Float, Float) -> Class.TangentVector
94+
// CHECK3: %[[#C48:]] = tuple (%[[#]], %[[#C45]])
95+
// CHECK3: return %[[#C48]]
96+
// CHECK3: } // end sil function '$s3outyycfU1_5ClassL_V6methodSfyFTJrSpSr'
97+
98+
// CHECK3-NONE: {{^}}// pullback of method
99+
// CHECK3: {{^}}// specialized pullback of method()
100+
// CHECK3: sil private @$s3outyycfU1_5ClassL_V6methodSfyFTJpSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktk1_l5FZSf_L6SfcfU_S2fAES2fTf1nncc_n : $@convention(thin) (Float, @owned _AD__$s3outyycfU1_5ClassL_V6methodSfyF_bb3__Pred__src_0_wrt_0, Float, Float, Float, Float) -> Class.TangentVector {
101+
102+
@differentiable(reverse)
103+
func method() -> Float {
104+
let c: Class
105+
do {
106+
let tmp = Class(stored: 1 * stored, optional: optional)
107+
let tuple = (tmp, tmp)
108+
c = tuple.0
109+
}
110+
var ret : Float = 0
111+
if let x = c.optional {
112+
ret = x * c.stored
113+
} else {
114+
ret = 1 * c.stored
115+
}
116+
return 1 * ret * ret
117+
}
118+
}
119+
120+
@differentiable(reverse)
121+
func methodWrapper(_ x: Class) -> Float {
122+
x.method()
123+
}
124+
125+
expectEqual(
126+
valueWithGradient(at: Class(stored: 3, optional: 4), of: methodWrapper),
127+
(144, .init(stored: 96, optional: .init(72))))
128+
expectEqual(
129+
valueWithGradient(at: Class(stored: 3, optional: nil), of: methodWrapper),
130+
(9, .init(stored: 6, optional: .init(0))))
131+
}
132+
133+
runAllTests()
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
/// Multi basic block VJP, pullback not accepting branch tracing enum argument.
2+
3+
// REQUIRES: executable_test
4+
5+
// RUN: %empty-directory(%t)
6+
// RUN: %target-build-swift %s -o %t/none.out -Onone
7+
// RUN: %target-build-swift %s -o %t/opt.out -O
8+
// RUN: %target-run %t/none.out
9+
// RUN: %target-run %t/opt.out
10+
11+
// RUN: %target-swift-frontend -emit-sil %s -O -o %t/out.sil
12+
// RUN: cat %t/out.sil | %FileCheck %s --check-prefix=CHECK1
13+
// RUN: cat %t/out.sil | %FileCheck %s --check-prefix=CHECK2
14+
// RUN: cat %t/out.sil | %FileCheck %s --check-prefix=CHECK3
15+
// RUN: cat %t/out.sil | %FileCheck %s --check-prefix=CHECK4
16+
17+
import DifferentiationUnittest
18+
import StdlibUnittest
19+
20+
var AutoDiffClosureSpecMultiBBNoBTETests = TestSuite("AutoDiffClosureSpecMultiBBNoBTE")
21+
22+
typealias FloatArrayTan = Array<Float>.TangentVector
23+
24+
AutoDiffClosureSpecMultiBBNoBTETests.testWithLeakChecking("Test1") {
25+
// CHECK1-LABEL: {{^}}// reverse-mode derivative of sumFirstThreeConcatenating1 #1 (_:_:)
26+
// CHECK1-NEXT: sil private @$s3outyycfU_27sumFirstThreeConcatenating1L_ySfSaySfG_ACtFTJrSSpSr : $@convention(thin) (@guaranteed Array<Float>, @guaranteed Array<Float>) -> (Float, @owned @callee_guaranteed (Float) -> (@owned Array<Float>.DifferentiableView, @owned Array<Float>.DifferentiableView)) {
27+
// CHECK1: %[[#E52:]] = function_ref @$s3outyycfU_27sumFirstThreeConcatenating1L_ySfSaySfG_ACtFTJpSSpSr055$sSfSa16_DifferentiationAA14DifferentiableRzlE0B4ViewVyg8_GIegno_G10AEIegyo_TRSfSa01_I0AE0K0RzlE0kO0VySf_GIegno_ADSfAIIegno_0f5Sf16_i26E7_vjpAdd3lhs3rhsSf5value_g17_SftSfc8pullbacktg1_y5FZSf_Y6SfcfU_ADSfAIIegno_AJTf1nnccccc_n0fh1_ijkl4E13_v32Subscript5indexx5value_SaA2aBRzlmnO59Vy13TangentVectorQz_GAIc8pullbacktSi_tFAKL_yAjiaBRzlFSf_TG5ACSiAkCSiAkCSiTf1nnccc_n : $@convention(thin) (Float, @owned @callee_guaranteed (@guaranteed Array<Float>.DifferentiableView) -> (@owned Array<Float>.DifferentiableView, @owned Array<Float>.DifferentiableView), @owned Array<Float>, Int, @owned Array<Float>, Int, @owned Array<Float>, Int) -> (@owned Array<Float>.DifferentiableView, @owned Array<Float>.DifferentiableView)
28+
// CHECK1: %[[#E53:]] = partial_apply [callee_guaranteed] %[[#E52]](%[[#]], %[[#]], %[[#]], %[[#]], %[[#]], %[[#]], %[[#]]) : $@convention(thin) (Float, @owned @callee_guaranteed (@guaranteed Array<Float>.DifferentiableView) -> (@owned Array<Float>.DifferentiableView, @owned Array<Float>.DifferentiableView), @owned Array<Float>, Int, @owned Array<Float>, Int, @owned Array<Float>, Int) -> (@owned Array<Float>.DifferentiableView, @owned Array<Float>.DifferentiableView)
29+
// CHECK1: %[[#E55:]] = tuple (%[[#]], %[[#E53]])
30+
// CHECK1: return %[[#E55]]
31+
// CHECK1: } // end sil function '$s3outyycfU_27sumFirstThreeConcatenating1L_ySfSaySfG_ACtFTJrSSpSr'
32+
33+
// CHECK1-NONE: {{^}}// pullback of sumFirstThreeConcatenating1
34+
// CHECK1: {{^}}// specialized pullback of sumFirstThreeConcatenating1
35+
// CHECK1: sil private @$s3outyycfU_27sumFirstThreeConcatenating1L_ySfSaySfG_ACtFTJpSSpSr055$sSfSa16_DifferentiationAA14DifferentiableRzlE0B4ViewVyg8_GIegno_G10AEIegyo_TRSfSa01_I0AE0K0RzlE0kO0VySf_GIegno_ADSfAIIegno_0f5Sf16_i26E7_vjpAdd3lhs3rhsSf5value_g17_SftSfc8pullbacktg1_y5FZSf_Y6SfcfU_ADSfAIIegno_AJTf1nnccccc_n0fh1_ijkl4E13_v32Subscript5indexx5value_SaA2aBRzlmnO59Vy13TangentVectorQz_GAIc8pullbacktSi_tFAKL_yAjiaBRzlFSf_TG5ACSiAkCSiAkCSiTf1nnccc_nTf4ngnnnnnn_n : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@guaranteed Array<Float>.DifferentiableView) -> (@owned Array<Float>.DifferentiableView, @owned Array<Float>.DifferentiableView), @owned Array<Float>, Int, @owned Array<Float>, Int, @owned Array<Float>, Int) -> (@owned Array<Float>.DifferentiableView, @owned Array<Float>.DifferentiableView) {
36+
func sumFirstThreeConcatenating1(_ a: [Float], _ b: [Float]) -> Float {
37+
let c = a + b
38+
return c[0] + c[1] + c[2]
39+
}
40+
41+
expectEqual(
42+
(.init([1, 1]), .init([1, 0])),
43+
gradient(at: [0, 0], [0, 0], of: sumFirstThreeConcatenating1))
44+
expectEqual(
45+
(.init([1, 1, 1, 0]), .init([0, 0])),
46+
gradient(at: [0, 0, 0, 0], [0, 0], of: sumFirstThreeConcatenating1))
47+
expectEqual(
48+
(.init([]), .init([1, 1, 1, 0])),
49+
gradient(at: [], [0, 0, 0, 0], of: sumFirstThreeConcatenating1))
50+
}
51+
52+
AutoDiffClosureSpecMultiBBNoBTETests.testWithLeakChecking("Test2") {
53+
// CHECK2-LABEL: {{^}}// reverse-mode derivative of sumFirstThreeConcatenating2 #1 (_:_:)
54+
// CHECK2-NEXT: sil private @$s3outyycfU0_27sumFirstThreeConcatenating2L_ySfSaySfG_ACtFTJrSSpSr : $@convention(thin) (@guaranteed Array<Float>, @guaranteed Array<Float>) -> (Float, @owned @callee_guaranteed (Float) -> (@owned Array<Float>.DifferentiableView, @owned Array<Float>.DifferentiableView)) {
55+
// CHECK2: %[[#E52:]] = function_ref @$s3outyycfU0_27sumFirstThreeConcatenating2L_ySfSaySfG_ACtFTJpSSpSr055$sSfSa16_DifferentiationAA14DifferentiableRzlE0B4ViewVyg8_GIegno_G10AEIegyo_TRSfSa01_I0AE0K0RzlE0kO0VySf_GIegno_ADSfAIIegno_0f5Sf16_i26E7_vjpAdd3lhs3rhsSf5value_g17_SftSfc8pullbacktg1_y5FZSf_Y6SfcfU_ADSfAIIegno_AJTf1nnccccc_n0fh1_ijkl4E10_v25Appendyyt5value_SaA2aBRzlmno55Vy13TangentVectorQz_GAIzc8pullbacktSayxGz_AKtFZA2IzcfU_G4_Tg5Si0fh1_ijkl4E13_v32Subscript5indexx5value_SaA2aBRzlmnO59Vy13TangentVectorQz_GAIc8pullbacktSi_tFAKL_yAjiaBRzlFSf_TG5ACSiAlCSiAlCSiTf1ncccc_n : $@convention(thin) (Float, Int, @owned Array<Float>, Int, @owned Array<Float>, Int, @owned Array<Float>, Int) -> (@owned Array<Float>.DifferentiableView, @owned Array<Float>.DifferentiableView)
56+
// CHECK2: %[[#E53:]] = partial_apply [callee_guaranteed] %[[#E52]](%[[#]], %[[#]], %[[#]], %[[#]], %[[#]], %[[#]], %[[#]]) : $@convention(thin) (Float, Int, @owned Array<Float>, Int, @owned Array<Float>, Int, @owned Array<Float>, Int) -> (@owned Array<Float>.DifferentiableView, @owned Array<Float>.DifferentiableView)
57+
// CHECK2: %[[#E55:]] = tuple (%[[#]], %[[#E53]])
58+
// CHECK2: return %[[#E55]]
59+
// CHECK2: } // end sil function '$s3outyycfU0_27sumFirstThreeConcatenating2L_ySfSaySfG_ACtFTJrSSpSr'
60+
61+
// CHECK2-NONE: {{^}}// pullback of sumFirstThreeConcatenating2
62+
// CHECK2: {{^}}// specialized pullback of sumFirstThreeConcatenating2
63+
// CHECK2: sil private @$s3outyycfU0_27sumFirstThreeConcatenating2L_ySfSaySfG_ACtFTJpSSpSr055$sSfSa16_DifferentiationAA14DifferentiableRzlE0B4ViewVyg8_GIegno_G10AEIegyo_TRSfSa01_I0AE0K0RzlE0kO0VySf_GIegno_ADSfAIIegno_0f5Sf16_i26E7_vjpAdd3lhs3rhsSf5value_g17_SftSfc8pullbacktg1_y5FZSf_Y6SfcfU_ADSfAIIegno_AJTf1nnccccc_n0fh1_ijkl4E10_v25Appendyyt5value_SaA2aBRzlmno55Vy13TangentVectorQz_GAIzc8pullbacktSayxGz_AKtFZA2IzcfU_G4_Tg5Si0fh1_ijkl4E13_v32Subscript5indexx5value_SaA2aBRzlmnO59Vy13TangentVectorQz_GAIc8pullbacktSi_tFAKL_yAjiaBRzlFSf_TG5ACSiAlCSiAlCSiTf1ncccc_n : $@convention(thin) (Float, Int, @owned Array<Float>, Int, @owned Array<Float>, Int, @owned Array<Float>, Int) -> (@owned Array<Float>.DifferentiableView, @owned Array<Float>.DifferentiableView) {
64+
65+
func sumFirstThreeConcatenating2(_ a: [Float], _ b: [Float]) -> Float {
66+
var c = a
67+
c += b
68+
return c[0] + c[1] + c[2]
69+
}
70+
71+
expectEqual(
72+
(.init([1, 1]), .init([1, 0])),
73+
gradient(at: [0, 0], [0, 0], of: sumFirstThreeConcatenating2))
74+
expectEqual(
75+
(.init([1, 1, 1, 0]), .init([0, 0])),
76+
gradient(at: [0, 0, 0, 0], [0, 0], of: sumFirstThreeConcatenating2))
77+
expectEqual(
78+
(.init([]), .init([1, 1, 1, 0])),
79+
gradient(at: [], [0, 0, 0, 0], of: sumFirstThreeConcatenating2))
80+
}
81+
82+
AutoDiffClosureSpecMultiBBNoBTETests.testWithLeakChecking("Test3") {
83+
@propertyWrapper
84+
enum Wrapper {
85+
case case1(Float)
86+
case case2(Float)
87+
88+
init(wrappedValue: Float) {
89+
self = .case1(wrappedValue)
90+
}
91+
92+
var wrappedValue: Float {
93+
get {
94+
switch self {
95+
case .case1(let val):
96+
return val
97+
case .case2(let val):
98+
return val * 2
99+
}
100+
}
101+
set {
102+
self = .case2(wrappedValue)
103+
}
104+
}
105+
}
106+
107+
struct RealPropertyWrappers: Differentiable {
108+
@Wrapper var x: Float = 3
109+
var y: Float = 4
110+
}
111+
112+
// CHECK3: {{^}}// reverse-mode derivative of multiply #1 (_:)
113+
// CHECK3-NEXT: sil private @$s3outyycfU1_8multiplyL_ySfAAyycfU1_20RealPropertyWrappersL_VFTJrSpSr : $@convention(thin) (RealPropertyWrappers) -> (Float, @owned @callee_guaranteed (Float) -> RealPropertyWrappers.TangentVector) {
114+
// CHECK3: %[[#A22:]] = function_ref @$s3outyycfU1_8multiplyL_ySfAAyycfU1_20RealPropertyWrappersL_VFTJpSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktm1_n5FZSf_N6SfcfU_S2fTf1nnc_n015$s3outyycfU1_20cdE16L_V1xSfvgTJpSpSrTf1ncnn_n : $@convention(thin) (Float, Float, Float) -> RealPropertyWrappers.TangentVector
115+
// CHECK3: %[[#A23:]] = partial_apply [callee_guaranteed] %[[#A22]](%[[#]], %[[#]]) : $@convention(thin) (Float, Float, Float) -> RealPropertyWrappers.TangentVector
116+
// CHECK3: %[[#A24:]] = tuple (%[[#]], %[[#A23]])
117+
// CHECK3: return %[[#A24]]
118+
// CHECK3: } // end sil function '$s3outyycfU1_8multiplyL_ySfAAyycfU1_20RealPropertyWrappersL_VFTJrSpSr'
119+
120+
// CHECK3-NONE: {{^}}// pullback of multiply
121+
// CHECK3: {{^}}// specialized pullback of multiply
122+
// CHECK3: sil private @$s3outyycfU1_8multiplyL_ySfAAyycfU1_20RealPropertyWrappersL_VFTJpSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktm1_n5FZSf_N6SfcfU_S2fTf1nnc_n015$s3outyycfU1_20cdE16L_V1xSfvgTJpSpSrTf1ncnn_n : $@convention(thin) (Float, Float, Float) -> RealPropertyWrappers.TangentVector {
123+
124+
@differentiable(reverse)
125+
func multiply(_ s: RealPropertyWrappers) -> Float {
126+
return s.x * s.y
127+
}
128+
129+
expectEqual(
130+
.init(x: 4, y: 3),
131+
gradient(at: RealPropertyWrappers(x: 3, y: 4), of: multiply))
132+
}
133+
134+
AutoDiffClosureSpecMultiBBNoBTETests.testWithLeakChecking("Test4") {
135+
struct Class: Differentiable {
136+
var stored: Float
137+
var optional: Float?
138+
139+
init(stored: Float, optional: Float?) {
140+
self.stored = stored
141+
self.optional = optional
142+
}
143+
144+
@differentiable(reverse)
145+
func method() -> Float {
146+
let c: Class
147+
do {
148+
let tmp = Class(stored: 1 * stored, optional: optional)
149+
let tuple = (tmp, tmp)
150+
c = tuple.0
151+
}
152+
if let x = c.optional {
153+
return x * c.stored
154+
}
155+
return 1 * c.stored
156+
}
157+
}
158+
159+
// CHECK4-LABEL: {{^}}// reverse-mode derivative of methodWrapper #1 (_:)
160+
// CHECK4-NEXT: sil private @$s3outyycfU2_13methodWrapperL_ySfAAyycfU2_5ClassL_VFTJrSpSr : $@convention(thin) (Class) -> (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) {
161+
// CHECK4: %[[#C39:]] = function_ref @$s3outyycfU2_13methodWrapperL_ySfAAyycfU2_5ClassL_VFTJpSpSr014$s3outyycfU2_5D21L_V6methodSfyFTJpSpSrAA05_AD__ef2_5d2L_gH24F_bb3__Pred__src_0_wrt_033_E588B908471A5F020CF23EC392ADD7D3LLOTf1nc_n : $@convention(thin) (Float, @owned _AD__$s3outyycfU2_5ClassL_V6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
162+
// CHECK4: %[[#C40:]] = partial_apply [callee_guaranteed] %[[#C39]](%[[#]]) : $@convention(thin) (Float, @owned _AD__$s3outyycfU2_5ClassL_V6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector
163+
// CHECK4: %[[#C42:]] = tuple (%[[#]], %[[#C40]])
164+
// CHECK4: return %[[#C42]]
165+
// CHECK4: } // end sil function '$s3outyycfU2_13methodWrapperL_ySfAAyycfU2_5ClassL_VFTJrSpSr'
166+
167+
/// TODO: even though branch tracing enum is not passed to top-level pullback
168+
/// directly, it is captured by one of the closures which was specialized.
169+
/// Because of that, this enum argument is now an argument of specialized top-level
170+
/// pullback. Specializing closures passed as payload tuple elements of the enum
171+
/// is currently not supported.
172+
173+
// CHECK4-NONE: {{^}}// pullback of methodWrapper
174+
// CHECK4: {{^}}// specialized pullback of methodWrapper
175+
// CHECK4: sil private @$s3outyycfU2_13methodWrapperL_ySfAAyycfU2_5ClassL_VFTJpSpSr014$s3outyycfU2_5D21L_V6methodSfyFTJpSpSrAA05_AD__ef2_5d2L_gH24F_bb3__Pred__src_0_wrt_033_E588B908471A5F020CF23EC392ADD7D3LLOTf1nc_n : $@convention(thin) (Float, @owned _AD__$s3outyycfU2_5ClassL_V6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector {
176+
177+
@differentiable(reverse)
178+
func methodWrapper(_ x: Class) -> Float {
179+
x.method()
180+
}
181+
182+
expectEqual(
183+
valueWithGradient(at: Class(stored: 3, optional: 4), of: methodWrapper),
184+
(12, .init(stored: 4, optional: .init(3))))
185+
expectEqual(
186+
valueWithGradient(at: Class(stored: 3, optional: nil), of: methodWrapper),
187+
(3, .init(stored: 1, optional: .init(0))))
188+
}
189+
190+
runAllTests()

0 commit comments

Comments
 (0)