|
| 1 | +/// Multi basic block VJP, pullback not accepting branch tracing enum argument. |
| 2 | + |
| 3 | +// REQUIRES: executable_test |
| 4 | + |
| 5 | +// RUN: %empty-directory(%t) |
| 6 | +// RUN: %target-build-swift %s -o %t/none.out -Onone |
| 7 | +// RUN: %target-build-swift %s -o %t/opt.out -O |
| 8 | +// RUN: %target-run %t/none.out |
| 9 | +// RUN: %target-run %t/opt.out |
| 10 | + |
| 11 | +// RUN: %target-swift-frontend -emit-sil %s -O -o %t/out.sil |
| 12 | +// RUN: cat %t/out.sil | %FileCheck %s --check-prefix=CHECK1 |
| 13 | +// RUN: cat %t/out.sil | %FileCheck %s --check-prefix=CHECK2 |
| 14 | +// RUN: cat %t/out.sil | %FileCheck %s --check-prefix=CHECK3 |
| 15 | +// RUN: cat %t/out.sil | %FileCheck %s --check-prefix=CHECK4 |
| 16 | + |
| 17 | +import DifferentiationUnittest |
| 18 | +import StdlibUnittest |
| 19 | + |
| 20 | +var AutoDiffClosureSpecMultiBBNoBTETests = TestSuite("AutoDiffClosureSpecMultiBBNoBTE") |
| 21 | + |
| 22 | +typealias FloatArrayTan = Array<Float>.TangentVector |
| 23 | + |
| 24 | +AutoDiffClosureSpecMultiBBNoBTETests.testWithLeakChecking("Test1") { |
| 25 | + // CHECK1-LABEL: {{^}}// reverse-mode derivative of sumFirstThreeConcatenating1 #1 (_:_:) |
| 26 | + // CHECK1-NEXT: sil private @$s3outyycfU_27sumFirstThreeConcatenating1L_ySfSaySfG_ACtFTJrSSpSr : $@convention(thin) (@guaranteed Array<Float>, @guaranteed Array<Float>) -> (Float, @owned @callee_guaranteed (Float) -> (@owned Array<Float>.DifferentiableView, @owned Array<Float>.DifferentiableView)) { |
| 27 | + // CHECK1: %[[#E52:]] = function_ref @$s3outyycfU_27sumFirstThreeConcatenating1L_ySfSaySfG_ACtFTJpSSpSr055$sSfSa16_DifferentiationAA14DifferentiableRzlE0B4ViewVyg8_GIegno_G10AEIegyo_TRSfSa01_I0AE0K0RzlE0kO0VySf_GIegno_ADSfAIIegno_0f5Sf16_i26E7_vjpAdd3lhs3rhsSf5value_g17_SftSfc8pullbacktg1_y5FZSf_Y6SfcfU_ADSfAIIegno_AJTf1nnccccc_n0fh1_ijkl4E13_v32Subscript5indexx5value_SaA2aBRzlmnO59Vy13TangentVectorQz_GAIc8pullbacktSi_tFAKL_yAjiaBRzlFSf_TG5ACSiAkCSiAkCSiTf1nnccc_n : $@convention(thin) (Float, @owned @callee_guaranteed (@guaranteed Array<Float>.DifferentiableView) -> (@owned Array<Float>.DifferentiableView, @owned Array<Float>.DifferentiableView), @owned Array<Float>, Int, @owned Array<Float>, Int, @owned Array<Float>, Int) -> (@owned Array<Float>.DifferentiableView, @owned Array<Float>.DifferentiableView) |
| 28 | + // CHECK1: %[[#E53:]] = partial_apply [callee_guaranteed] %[[#E52]](%[[#]], %[[#]], %[[#]], %[[#]], %[[#]], %[[#]], %[[#]]) : $@convention(thin) (Float, @owned @callee_guaranteed (@guaranteed Array<Float>.DifferentiableView) -> (@owned Array<Float>.DifferentiableView, @owned Array<Float>.DifferentiableView), @owned Array<Float>, Int, @owned Array<Float>, Int, @owned Array<Float>, Int) -> (@owned Array<Float>.DifferentiableView, @owned Array<Float>.DifferentiableView) |
| 29 | + // CHECK1: %[[#E55:]] = tuple (%[[#]], %[[#E53]]) |
| 30 | + // CHECK1: return %[[#E55]] |
| 31 | + // CHECK1: } // end sil function '$s3outyycfU_27sumFirstThreeConcatenating1L_ySfSaySfG_ACtFTJrSSpSr' |
| 32 | + |
| 33 | + // CHECK1-NONE: {{^}}// pullback of sumFirstThreeConcatenating1 |
| 34 | + // CHECK1: {{^}}// specialized pullback of sumFirstThreeConcatenating1 |
| 35 | + // CHECK1: sil private @$s3outyycfU_27sumFirstThreeConcatenating1L_ySfSaySfG_ACtFTJpSSpSr055$sSfSa16_DifferentiationAA14DifferentiableRzlE0B4ViewVyg8_GIegno_G10AEIegyo_TRSfSa01_I0AE0K0RzlE0kO0VySf_GIegno_ADSfAIIegno_0f5Sf16_i26E7_vjpAdd3lhs3rhsSf5value_g17_SftSfc8pullbacktg1_y5FZSf_Y6SfcfU_ADSfAIIegno_AJTf1nnccccc_n0fh1_ijkl4E13_v32Subscript5indexx5value_SaA2aBRzlmnO59Vy13TangentVectorQz_GAIc8pullbacktSi_tFAKL_yAjiaBRzlFSf_TG5ACSiAkCSiAkCSiTf1nnccc_nTf4ngnnnnnn_n : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@guaranteed Array<Float>.DifferentiableView) -> (@owned Array<Float>.DifferentiableView, @owned Array<Float>.DifferentiableView), @owned Array<Float>, Int, @owned Array<Float>, Int, @owned Array<Float>, Int) -> (@owned Array<Float>.DifferentiableView, @owned Array<Float>.DifferentiableView) { |
| 36 | + func sumFirstThreeConcatenating1(_ a: [Float], _ b: [Float]) -> Float { |
| 37 | + let c = a + b |
| 38 | + return c[0] + c[1] + c[2] |
| 39 | + } |
| 40 | + |
| 41 | + expectEqual( |
| 42 | + (.init([1, 1]), .init([1, 0])), |
| 43 | + gradient(at: [0, 0], [0, 0], of: sumFirstThreeConcatenating1)) |
| 44 | + expectEqual( |
| 45 | + (.init([1, 1, 1, 0]), .init([0, 0])), |
| 46 | + gradient(at: [0, 0, 0, 0], [0, 0], of: sumFirstThreeConcatenating1)) |
| 47 | + expectEqual( |
| 48 | + (.init([]), .init([1, 1, 1, 0])), |
| 49 | + gradient(at: [], [0, 0, 0, 0], of: sumFirstThreeConcatenating1)) |
| 50 | +} |
| 51 | + |
| 52 | +AutoDiffClosureSpecMultiBBNoBTETests.testWithLeakChecking("Test2") { |
| 53 | + // CHECK2-LABEL: {{^}}// reverse-mode derivative of sumFirstThreeConcatenating2 #1 (_:_:) |
| 54 | + // CHECK2-NEXT: sil private @$s3outyycfU0_27sumFirstThreeConcatenating2L_ySfSaySfG_ACtFTJrSSpSr : $@convention(thin) (@guaranteed Array<Float>, @guaranteed Array<Float>) -> (Float, @owned @callee_guaranteed (Float) -> (@owned Array<Float>.DifferentiableView, @owned Array<Float>.DifferentiableView)) { |
| 55 | + // CHECK2: %[[#E52:]] = function_ref @$s3outyycfU0_27sumFirstThreeConcatenating2L_ySfSaySfG_ACtFTJpSSpSr055$sSfSa16_DifferentiationAA14DifferentiableRzlE0B4ViewVyg8_GIegno_G10AEIegyo_TRSfSa01_I0AE0K0RzlE0kO0VySf_GIegno_ADSfAIIegno_0f5Sf16_i26E7_vjpAdd3lhs3rhsSf5value_g17_SftSfc8pullbacktg1_y5FZSf_Y6SfcfU_ADSfAIIegno_AJTf1nnccccc_n0fh1_ijkl4E10_v25Appendyyt5value_SaA2aBRzlmno55Vy13TangentVectorQz_GAIzc8pullbacktSayxGz_AKtFZA2IzcfU_G4_Tg5Si0fh1_ijkl4E13_v32Subscript5indexx5value_SaA2aBRzlmnO59Vy13TangentVectorQz_GAIc8pullbacktSi_tFAKL_yAjiaBRzlFSf_TG5ACSiAlCSiAlCSiTf1ncccc_n : $@convention(thin) (Float, Int, @owned Array<Float>, Int, @owned Array<Float>, Int, @owned Array<Float>, Int) -> (@owned Array<Float>.DifferentiableView, @owned Array<Float>.DifferentiableView) |
| 56 | + // CHECK2: %[[#E53:]] = partial_apply [callee_guaranteed] %[[#E52]](%[[#]], %[[#]], %[[#]], %[[#]], %[[#]], %[[#]], %[[#]]) : $@convention(thin) (Float, Int, @owned Array<Float>, Int, @owned Array<Float>, Int, @owned Array<Float>, Int) -> (@owned Array<Float>.DifferentiableView, @owned Array<Float>.DifferentiableView) |
| 57 | + // CHECK2: %[[#E55:]] = tuple (%[[#]], %[[#E53]]) |
| 58 | + // CHECK2: return %[[#E55]] |
| 59 | + // CHECK2: } // end sil function '$s3outyycfU0_27sumFirstThreeConcatenating2L_ySfSaySfG_ACtFTJrSSpSr' |
| 60 | + |
| 61 | + // CHECK2-NONE: {{^}}// pullback of sumFirstThreeConcatenating2 |
| 62 | + // CHECK2: {{^}}// specialized pullback of sumFirstThreeConcatenating2 |
| 63 | + // CHECK2: sil private @$s3outyycfU0_27sumFirstThreeConcatenating2L_ySfSaySfG_ACtFTJpSSpSr055$sSfSa16_DifferentiationAA14DifferentiableRzlE0B4ViewVyg8_GIegno_G10AEIegyo_TRSfSa01_I0AE0K0RzlE0kO0VySf_GIegno_ADSfAIIegno_0f5Sf16_i26E7_vjpAdd3lhs3rhsSf5value_g17_SftSfc8pullbacktg1_y5FZSf_Y6SfcfU_ADSfAIIegno_AJTf1nnccccc_n0fh1_ijkl4E10_v25Appendyyt5value_SaA2aBRzlmno55Vy13TangentVectorQz_GAIzc8pullbacktSayxGz_AKtFZA2IzcfU_G4_Tg5Si0fh1_ijkl4E13_v32Subscript5indexx5value_SaA2aBRzlmnO59Vy13TangentVectorQz_GAIc8pullbacktSi_tFAKL_yAjiaBRzlFSf_TG5ACSiAlCSiAlCSiTf1ncccc_n : $@convention(thin) (Float, Int, @owned Array<Float>, Int, @owned Array<Float>, Int, @owned Array<Float>, Int) -> (@owned Array<Float>.DifferentiableView, @owned Array<Float>.DifferentiableView) { |
| 64 | + |
| 65 | + func sumFirstThreeConcatenating2(_ a: [Float], _ b: [Float]) -> Float { |
| 66 | + var c = a |
| 67 | + c += b |
| 68 | + return c[0] + c[1] + c[2] |
| 69 | + } |
| 70 | + |
| 71 | + expectEqual( |
| 72 | + (.init([1, 1]), .init([1, 0])), |
| 73 | + gradient(at: [0, 0], [0, 0], of: sumFirstThreeConcatenating2)) |
| 74 | + expectEqual( |
| 75 | + (.init([1, 1, 1, 0]), .init([0, 0])), |
| 76 | + gradient(at: [0, 0, 0, 0], [0, 0], of: sumFirstThreeConcatenating2)) |
| 77 | + expectEqual( |
| 78 | + (.init([]), .init([1, 1, 1, 0])), |
| 79 | + gradient(at: [], [0, 0, 0, 0], of: sumFirstThreeConcatenating2)) |
| 80 | +} |
| 81 | + |
| 82 | +AutoDiffClosureSpecMultiBBNoBTETests.testWithLeakChecking("Test3") { |
| 83 | + @propertyWrapper |
| 84 | + enum Wrapper { |
| 85 | + case case1(Float) |
| 86 | + case case2(Float) |
| 87 | + |
| 88 | + init(wrappedValue: Float) { |
| 89 | + self = .case1(wrappedValue) |
| 90 | + } |
| 91 | + |
| 92 | + var wrappedValue: Float { |
| 93 | + get { |
| 94 | + switch self { |
| 95 | + case .case1(let val): |
| 96 | + return val |
| 97 | + case .case2(let val): |
| 98 | + return val * 2 |
| 99 | + } |
| 100 | + } |
| 101 | + set { |
| 102 | + self = .case2(wrappedValue) |
| 103 | + } |
| 104 | + } |
| 105 | + } |
| 106 | + |
| 107 | + struct RealPropertyWrappers: Differentiable { |
| 108 | + @Wrapper var x: Float = 3 |
| 109 | + var y: Float = 4 |
| 110 | + } |
| 111 | + |
| 112 | + // CHECK3: {{^}}// reverse-mode derivative of multiply #1 (_:) |
| 113 | + // CHECK3-NEXT: sil private @$s3outyycfU1_8multiplyL_ySfAAyycfU1_20RealPropertyWrappersL_VFTJrSpSr : $@convention(thin) (RealPropertyWrappers) -> (Float, @owned @callee_guaranteed (Float) -> RealPropertyWrappers.TangentVector) { |
| 114 | + // CHECK3: %[[#A22:]] = function_ref @$s3outyycfU1_8multiplyL_ySfAAyycfU1_20RealPropertyWrappersL_VFTJpSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktm1_n5FZSf_N6SfcfU_S2fTf1nnc_n015$s3outyycfU1_20cdE16L_V1xSfvgTJpSpSrTf1ncnn_n : $@convention(thin) (Float, Float, Float) -> RealPropertyWrappers.TangentVector |
| 115 | + // CHECK3: %[[#A23:]] = partial_apply [callee_guaranteed] %[[#A22]](%[[#]], %[[#]]) : $@convention(thin) (Float, Float, Float) -> RealPropertyWrappers.TangentVector |
| 116 | + // CHECK3: %[[#A24:]] = tuple (%[[#]], %[[#A23]]) |
| 117 | + // CHECK3: return %[[#A24]] |
| 118 | + // CHECK3: } // end sil function '$s3outyycfU1_8multiplyL_ySfAAyycfU1_20RealPropertyWrappersL_VFTJrSpSr' |
| 119 | + |
| 120 | + // CHECK3-NONE: {{^}}// pullback of multiply |
| 121 | + // CHECK3: {{^}}// specialized pullback of multiply |
| 122 | + // CHECK3: sil private @$s3outyycfU1_8multiplyL_ySfAAyycfU1_20RealPropertyWrappersL_VFTJpSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktm1_n5FZSf_N6SfcfU_S2fTf1nnc_n015$s3outyycfU1_20cdE16L_V1xSfvgTJpSpSrTf1ncnn_n : $@convention(thin) (Float, Float, Float) -> RealPropertyWrappers.TangentVector { |
| 123 | + |
| 124 | + @differentiable(reverse) |
| 125 | + func multiply(_ s: RealPropertyWrappers) -> Float { |
| 126 | + return s.x * s.y |
| 127 | + } |
| 128 | + |
| 129 | + expectEqual( |
| 130 | + .init(x: 4, y: 3), |
| 131 | + gradient(at: RealPropertyWrappers(x: 3, y: 4), of: multiply)) |
| 132 | +} |
| 133 | + |
| 134 | +AutoDiffClosureSpecMultiBBNoBTETests.testWithLeakChecking("Test4") { |
| 135 | + struct Class: Differentiable { |
| 136 | + var stored: Float |
| 137 | + var optional: Float? |
| 138 | + |
| 139 | + init(stored: Float, optional: Float?) { |
| 140 | + self.stored = stored |
| 141 | + self.optional = optional |
| 142 | + } |
| 143 | + |
| 144 | + @differentiable(reverse) |
| 145 | + func method() -> Float { |
| 146 | + let c: Class |
| 147 | + do { |
| 148 | + let tmp = Class(stored: 1 * stored, optional: optional) |
| 149 | + let tuple = (tmp, tmp) |
| 150 | + c = tuple.0 |
| 151 | + } |
| 152 | + if let x = c.optional { |
| 153 | + return x * c.stored |
| 154 | + } |
| 155 | + return 1 * c.stored |
| 156 | + } |
| 157 | + } |
| 158 | + |
| 159 | + // CHECK4-LABEL: {{^}}// reverse-mode derivative of methodWrapper #1 (_:) |
| 160 | + // CHECK4-NEXT: sil private @$s3outyycfU2_13methodWrapperL_ySfAAyycfU2_5ClassL_VFTJrSpSr : $@convention(thin) (Class) -> (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) { |
| 161 | + // CHECK4: %[[#C39:]] = function_ref @$s3outyycfU2_13methodWrapperL_ySfAAyycfU2_5ClassL_VFTJpSpSr014$s3outyycfU2_5D21L_V6methodSfyFTJpSpSrAA05_AD__ef2_5d2L_gH24F_bb3__Pred__src_0_wrt_033_E588B908471A5F020CF23EC392ADD7D3LLOTf1nc_n : $@convention(thin) (Float, @owned _AD__$s3outyycfU2_5ClassL_V6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector |
| 162 | + // CHECK4: %[[#C40:]] = partial_apply [callee_guaranteed] %[[#C39]](%[[#]]) : $@convention(thin) (Float, @owned _AD__$s3outyycfU2_5ClassL_V6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector |
| 163 | + // CHECK4: %[[#C42:]] = tuple (%[[#]], %[[#C40]]) |
| 164 | + // CHECK4: return %[[#C42]] |
| 165 | + // CHECK4: } // end sil function '$s3outyycfU2_13methodWrapperL_ySfAAyycfU2_5ClassL_VFTJrSpSr' |
| 166 | + |
| 167 | + /// TODO: even though branch tracing enum is not passed to top-level pullback |
| 168 | + /// directly, it is captured by one of the closures which was specialized. |
| 169 | + /// Because of that, this enum argument is now an argument of specialized top-level |
| 170 | + /// pullback. Specializing closures passed as payload tuple elements of the enum |
| 171 | + /// is currently not supported. |
| 172 | + |
| 173 | + // CHECK4-NONE: {{^}}// pullback of methodWrapper |
| 174 | + // CHECK4: {{^}}// specialized pullback of methodWrapper |
| 175 | + // CHECK4: sil private @$s3outyycfU2_13methodWrapperL_ySfAAyycfU2_5ClassL_VFTJpSpSr014$s3outyycfU2_5D21L_V6methodSfyFTJpSpSrAA05_AD__ef2_5d2L_gH24F_bb3__Pred__src_0_wrt_033_E588B908471A5F020CF23EC392ADD7D3LLOTf1nc_n : $@convention(thin) (Float, @owned _AD__$s3outyycfU2_5ClassL_V6methodSfyF_bb3__Pred__src_0_wrt_0) -> Class.TangentVector { |
| 176 | + |
| 177 | + @differentiable(reverse) |
| 178 | + func methodWrapper(_ x: Class) -> Float { |
| 179 | + x.method() |
| 180 | + } |
| 181 | + |
| 182 | + expectEqual( |
| 183 | + valueWithGradient(at: Class(stored: 3, optional: 4), of: methodWrapper), |
| 184 | + (12, .init(stored: 4, optional: .init(3)))) |
| 185 | + expectEqual( |
| 186 | + valueWithGradient(at: Class(stored: 3, optional: nil), of: methodWrapper), |
| 187 | + (3, .init(stored: 1, optional: .init(0)))) |
| 188 | +} |
| 189 | + |
| 190 | +runAllTests() |
0 commit comments