From 8b6704f2593f670d124a41abc8a691b5532ec57c Mon Sep 17 00:00:00 2001 From: Brad Larson Date: Thu, 22 Jul 2021 09:19:51 -0500 Subject: [PATCH 01/12] Initial easing of checks for differentiability of inouts with results. --- include/swift/AST/AutoDiff.h | 4 ---- include/swift/AST/DiagnosticsSema.def | 3 --- lib/AST/AutoDiff.cpp | 3 --- lib/AST/Type.cpp | 5 ----- lib/Sema/TypeCheckAttr.cpp | 12 ------------ .../Sema/derivative_attr_type_checking.swift | 9 ++------- .../Sema/differentiable_attr_type_checking.swift | 5 ----- 7 files changed, 2 insertions(+), 39 deletions(-) diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index 3bcdc5dc0852e..4b692d2e7f51a 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -398,9 +398,6 @@ class DerivativeFunctionTypeError enum class Kind { /// Original function type has no semantic results. NoSemanticResults, - /// Original function type has multiple semantic results. - // TODO(TF-1250): Support function types with multiple semantic results. - MultipleSemanticResults, /// Differentiability parmeter indices are empty. NoDifferentiabilityParameters, /// A differentiability parameter does not conform to `Differentiable`. @@ -429,7 +426,6 @@ class DerivativeFunctionTypeError explicit DerivativeFunctionTypeError(AnyFunctionType *functionType, Kind kind) : functionType(functionType), kind(kind), value(Value()) { assert(kind == Kind::NoSemanticResults || - kind == Kind::MultipleSemanticResults || kind == Kind::NoDifferentiabilityParameters); }; diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 4c372593c3e96..cb4256ca2d7fc 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -3496,9 +3496,6 @@ NOTE(autodiff_attr_original_decl_not_same_type_context,none, (DescriptiveDeclKind)) ERROR(autodiff_attr_original_void_result,none, "cannot differentiate void function %0", (DeclName)) -ERROR(autodiff_attr_original_multiple_semantic_results,none, - "cannot differentiate functions with both an 'inout' parameter and a " - "result", ()) ERROR(autodiff_attr_result_not_differentiable,none, "can only differentiate functions with results that conform to " "'Differentiable', but %0 does not conform to 'Differentiable'", (Type)) diff --git a/lib/AST/AutoDiff.cpp b/lib/AST/AutoDiff.cpp index 60a49b43e304f..67c7b88a49689 100644 --- a/lib/AST/AutoDiff.cpp +++ b/lib/AST/AutoDiff.cpp @@ -395,9 +395,6 @@ void DerivativeFunctionTypeError::log(raw_ostream &OS) const { case Kind::NoSemanticResults: OS << "has no semantic results ('Void' result)"; break; - case Kind::MultipleSemanticResults: - OS << "has multiple semantic results"; - break; case Kind::NoDifferentiabilityParameters: OS << "has no differentiability parameters"; break; diff --git a/lib/AST/Type.cpp b/lib/AST/Type.cpp index 3a869beab1a2c..b3f8108302c00 100644 --- a/lib/AST/Type.cpp +++ b/lib/AST/Type.cpp @@ -6439,11 +6439,6 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( if (originalResults.empty()) return llvm::make_error( this, DerivativeFunctionTypeError::Kind::NoSemanticResults); - // Error if multiple original semantic results. - // TODO(TF-1250): Support functions with multiple semantic results. - if (originalResults.size() > 1) - return llvm::make_error( - this, DerivativeFunctionTypeError::Kind::MultipleSemanticResults); auto originalResult = originalResults.front(); auto originalResultType = originalResult.type; diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 8061f84362915..827428412a6fe 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -4909,12 +4909,6 @@ bool resolveDifferentiableAttrDifferentiabilityParameters( original->getName()) .highlight(original->getSourceRange()); return; - case DerivativeFunctionTypeError::Kind::MultipleSemanticResults: - diags - .diagnose(attr->getLocation(), - diag::autodiff_attr_original_multiple_semantic_results) - .highlight(original->getSourceRange()); - return; case DerivativeFunctionTypeError::Kind::NoDifferentiabilityParameters: diags.diagnose(attr->getLocation(), diag::diff_params_clause_no_inferred_parameters); @@ -5427,12 +5421,6 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) { originalAFD->getName()) .highlight(attr->getOriginalFunctionName().Loc.getSourceRange()); return; - case DerivativeFunctionTypeError::Kind::MultipleSemanticResults: - diags - .diagnose(attr->getLocation(), - diag::autodiff_attr_original_multiple_semantic_results) - .highlight(attr->getOriginalFunctionName().Loc.getSourceRange()); - return; case DerivativeFunctionTypeError::Kind::NoDifferentiabilityParameters: diags.diagnose(attr->getLocation(), diag::diff_params_clause_no_inferred_parameters); diff --git a/test/AutoDiff/Sema/derivative_attr_type_checking.swift b/test/AutoDiff/Sema/derivative_attr_type_checking.swift index 8e99b4bdab4e0..e626209142147 100644 --- a/test/AutoDiff/Sema/derivative_attr_type_checking.swift +++ b/test/AutoDiff/Sema/derivative_attr_type_checking.swift @@ -746,13 +746,10 @@ extension ProtocolRequirementDerivative { func multipleSemanticResults(_ x: inout Float) -> Float { return x } -// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}} @derivative(of: multipleSemanticResults) func vjpMultipleSemanticResults(x: inout Float) -> ( - value: Float, pullback: (Float) -> Float -) { - return (multipleSemanticResults(&x), { $0 }) -} + value: Float, pullback: (inout Float) -> Void +) { fatalError() } struct InoutParameters: Differentiable { typealias TangentVector = DummyTangentVector @@ -885,14 +882,12 @@ func vjpNoSemanticResults(_ x: Float) -> (value: Void, pullback: Void) {} extension InoutParameters { func multipleSemanticResults(_ x: inout Float) -> Float { x } - // expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}} @derivative(of: multipleSemanticResults) func vjpMultipleSemanticResults(_ x: inout Float) -> ( value: Float, pullback: (inout Float) -> Void ) { fatalError() } func inoutVoid(_ x: Float, _ void: inout Void) -> Float {} - // expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}} @derivative(of: inoutVoid) func vjpInoutVoidParameter(_ x: Float, _ void: inout Void) -> ( value: Float, pullback: (inout Float) -> Void diff --git a/test/AutoDiff/Sema/differentiable_attr_type_checking.swift b/test/AutoDiff/Sema/differentiable_attr_type_checking.swift index 0cd6fa5b1bdb1..16cfecfd23e1a 100644 --- a/test/AutoDiff/Sema/differentiable_attr_type_checking.swift +++ b/test/AutoDiff/Sema/differentiable_attr_type_checking.swift @@ -528,7 +528,6 @@ func two9(x: Float, y: Float) -> Float { func inout1(x: Float, y: inout Float) -> Void { let _ = x + y } -// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}} @differentiable(reverse, wrt: y) func inout2(x: Float, y: inout Float) -> Float { let _ = x + y @@ -670,11 +669,9 @@ final class FinalClass: Differentiable { @differentiable(reverse, wrt: y) func inoutVoid(x: Float, y: inout Float) {} -// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}} @differentiable(reverse) func multipleSemanticResults(_ x: inout Float) -> Float { x } -// expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}} @differentiable(reverse, wrt: y) func swap(x: inout Float, y: inout Float) {} @@ -687,7 +684,6 @@ extension InoutParameters { @differentiable(reverse) static func staticMethod(_ lhs: inout Self, rhs: Self) {} - // expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}} @differentiable(reverse) static func multipleSemanticResults(_ lhs: inout Self, rhs: Self) -> Self {} } @@ -696,7 +692,6 @@ extension InoutParameters { @differentiable(reverse) mutating func mutatingMethod(_ other: Self) {} - // expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}} @differentiable(reverse) mutating func mutatingMethod(_ other: Self) -> Self {} } From 2c7672acf2fd0d7b1f211019092d847c67af2e21 Mon Sep 17 00:00:00 2001 From: Brad Larson Date: Sun, 25 Jul 2021 14:07:54 -0500 Subject: [PATCH 02/12] Initial attempt at allowing for multiple results in getAutoDiffDerivativeFunctionLinearMapType(). Fails for certain non-wrt inouts. --- lib/AST/Type.cpp | 94 +++++++++++++------ .../Sema/derivative_attr_type_checking.swift | 14 ++- .../validation-test/simple_math.swift | 26 +++++ 3 files changed, 103 insertions(+), 31 deletions(-) diff --git a/lib/AST/Type.cpp b/lib/AST/Type.cpp index b3f8108302c00..70820b99c3352 100644 --- a/lib/AST/Type.cpp +++ b/lib/AST/Type.cpp @@ -6432,26 +6432,38 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( getSubsetParameters(parameterIndices, diffParams, /*reverseCurryLevels*/ !makeSelfParamFirst); - // Get the original semantic result type. + // Get the original non-inout semantic result types. SmallVector originalResults; autodiff::getFunctionSemanticResultTypes(this, originalResults); // Error if no original semantic results. if (originalResults.empty()) return llvm::make_error( this, DerivativeFunctionTypeError::Kind::NoSemanticResults); - auto originalResult = originalResults.front(); - auto originalResultType = originalResult.type; - - // Get the original semantic result type's `TangentVector` associated type. - auto resultTan = - originalResultType->getAutoDiffTangentSpace(lookupConformance); - // Error if original semantic result has no tangent space. - if (!resultTan) { + // Accumulate non-inout result tangent spaces. + SmallVector resultTanTypes; + bool hasInoutResult = false; + for (auto i : range(originalResults.size())) { + auto originalResult = originalResults[i]; + if (originalResult.isInout) { + hasInoutResult = true; + continue; + } + auto originalResultType = originalResult.type; + // Get the original semantic result type's `TangentVector` associated type. + auto resultTan = + originalResultType->getAutoDiffTangentSpace(lookupConformance); + if (!resultTan) + continue; + auto resultTanType = resultTan->getType(); + resultTanTypes.push_back(resultTanType); + } + + // Error if no semantic result has a tangent space. + if (resultTanTypes.empty() && !hasInoutResult) { return llvm::make_error( this, DerivativeFunctionTypeError::Kind::NonDifferentiableResult, - std::make_pair(originalResultType, /*index*/ 0)); + std::make_pair(originalResults.front().type, /*index*/ 0)); } - auto resultTanType = resultTan->getType(); // Compute the result linear map function type. FunctionType *linearMapType; @@ -6467,11 +6479,10 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( // - Original: `(T0, inout T1, ...) -> Void` // - Differential: `(T0.Tan, ...) -> T1.Tan` // - // Case 3: original function has a wrt `inout` parameter. - // - Original: `(T0, inout T1, ...) -> Void` - // - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void` + // Case 3: original function has wrt `inout` parameters. + // - Original: `(T0, inout T1, ...) -> R` + // - Differential: `(T0.Tan, inout T1.Tan, ...) -> R.Tan` SmallVector differentialParams; - bool hasInoutDiffParameter = false; for (auto i : range(diffParams.size())) { auto diffParam = diffParams[i]; auto paramType = diffParam.getPlainType(); @@ -6486,11 +6497,23 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( } differentialParams.push_back(AnyFunctionType::Param( paramTan->getType(), Identifier(), diffParam.getParameterFlags())); - if (diffParam.isInOut()) - hasInoutDiffParameter = true; } - auto differentialResult = - hasInoutDiffParameter ? Type(ctx.TheEmptyTupleType) : resultTanType; + Type differentialResult; + if (resultTanTypes.empty()) { + differentialResult = ctx.TheEmptyTupleType; + } else if (resultTanTypes.size() == 1) { + differentialResult = resultTanTypes.front(); + } else { + SmallVector differentialResults; + for (auto i : range(resultTanTypes.size())) { + auto resultTanType = resultTanTypes[i]; + auto flags = ParameterTypeFlags().withInOut(false); + differentialResults.push_back( + TupleTypeElt(resultTanType, Identifier(), flags)); + } + differentialResult = TupleType::get(differentialResults, ctx); + } + // FIXME: Verify ExtInfo state is correct, not working by accident. FunctionType::ExtInfo info; linearMapType = @@ -6508,11 +6531,11 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( // - Original: `(T0, inout T1, ...) -> Void` // - Pullback: `(T1.Tan) -> (T0.Tan, ...)` // - // Case 3: original function has a wrt `inout` parameter. - // - Original: `(T0, inout T1, ...) -> Void` - // - Pullback: `(inout T1.Tan) -> (T0.Tan, ...)` + // Case 3: original function has wrt `inout` parameters. + // - Original: `(T0, inout T1, ...) -> R` + // - Pullback: `(R.Tan, inout T1.Tan) -> (T0.Tan, ...)` SmallVector pullbackResults; - bool hasInoutDiffParameter = false; + SmallVector inoutParams; for (auto i : range(diffParams.size())) { auto diffParam = diffParams[i]; auto paramType = diffParam.getPlainType(); @@ -6526,7 +6549,7 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( std::make_pair(paramType, i)); } if (diffParam.isInOut()) { - hasInoutDiffParameter = true; + inoutParams.push_back(diffParam); continue; } pullbackResults.emplace_back(paramTan->getType()); @@ -6539,12 +6562,27 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( } else { pullbackResult = TupleType::get(pullbackResults, ctx); } - auto flags = ParameterTypeFlags().withInOut(hasInoutDiffParameter); - auto pullbackParam = - AnyFunctionType::Param(resultTanType, Identifier(), flags); + // First accumulate non-inout results as pullback parameters. + SmallVector pullbackParams; + for (auto i : range(resultTanTypes.size())) { + auto resultTanType = resultTanTypes[i]; + auto flags = ParameterTypeFlags().withInOut(false); + pullbackParams.push_back(AnyFunctionType::Param( + resultTanType, Identifier(), flags)); + } + // Then append inout parameters. + for (auto i : range(inoutParams.size())) { + auto inoutParam = inoutParams[i]; + auto inoutParamType = inoutParam.getPlainType(); + auto inoutParamTan = + inoutParamType->getAutoDiffTangentSpace(lookupConformance); + auto flags = ParameterTypeFlags().withInOut(true); + pullbackParams.push_back(AnyFunctionType::Param( + inoutParamTan->getType(), Identifier(), flags)); + } // FIXME: Verify ExtInfo state is correct, not working by accident. FunctionType::ExtInfo info; - linearMapType = FunctionType::get({pullbackParam}, pullbackResult, info); + linearMapType = FunctionType::get(pullbackParams, pullbackResult, info); break; } } diff --git a/test/AutoDiff/Sema/derivative_attr_type_checking.swift b/test/AutoDiff/Sema/derivative_attr_type_checking.swift index e626209142147..2fb038dd3fe58 100644 --- a/test/AutoDiff/Sema/derivative_attr_type_checking.swift +++ b/test/AutoDiff/Sema/derivative_attr_type_checking.swift @@ -748,7 +748,15 @@ func multipleSemanticResults(_ x: inout Float) -> Float { } @derivative(of: multipleSemanticResults) func vjpMultipleSemanticResults(x: inout Float) -> ( - value: Float, pullback: (inout Float) -> Void + value: Float, pullback: (Float, inout Float) -> Void +) { fatalError() } + +func inoutNonDifferentiableResult(_ x: inout Float) -> Int { + return 5 +} +@derivative(of: inoutNonDifferentiableResult) +func vjpInoutNonDifferentiableResult(x: inout Float) -> ( + value: Int, pullback: (inout Float) -> Void ) { fatalError() } struct InoutParameters: Differentiable { @@ -884,13 +892,13 @@ extension InoutParameters { func multipleSemanticResults(_ x: inout Float) -> Float { x } @derivative(of: multipleSemanticResults) func vjpMultipleSemanticResults(_ x: inout Float) -> ( - value: Float, pullback: (inout Float) -> Void + value: Float, pullback: (Float, inout Float) -> Void ) { fatalError() } func inoutVoid(_ x: Float, _ void: inout Void) -> Float {} @derivative(of: inoutVoid) func vjpInoutVoidParameter(_ x: Float, _ void: inout Void) -> ( - value: Float, pullback: (inout Float) -> Void + value: Float, pullback: (Float) -> Float ) { fatalError() } } diff --git a/test/AutoDiff/validation-test/simple_math.swift b/test/AutoDiff/validation-test/simple_math.swift index 88b33e0ecfeaf..c01202d83c7cd 100644 --- a/test/AutoDiff/validation-test/simple_math.swift +++ b/test/AutoDiff/validation-test/simple_math.swift @@ -121,6 +121,32 @@ SimpleMathTests.test("MultipleResults") { expectEqual((4, 3), gradient(at: 3, 4, of: multiply_swapAndReturnProduct)) } +// Test function with multiple `inout` parameters and a custom pullback. +@differentiable(reverse) +func swapCustom(_ x: inout Float, _ y: inout Float) { + let tmp = x; x = y; y = tmp +} +@derivative(of: swapCustom) +func vjpSwapCustom(_ x: inout Float, _ y: inout Float) -> ( + value: Void, pullback: (inout Float, inout Float) -> Void +) { + swapCustom(&x, &y) + return ((), {v1, v2 in + let tmp = v1; v1 = v2; v2 = v1 + }) +} + +SimpleMathTests.test("MultipleResultsWithCustomPullback") { + func multiply_swapCustom(_ x: Float, _ y: Float) -> Float { + var tuple = (x, y) + swapCustom(&tuple.0, &tuple.1) + return tuple.0 * tuple.1 + } + + expectEqual((4, 3), gradient(at: 3, 4, of: multiply_swapCustom)) + expectEqual((10, 5), gradient(at: 5, 10, of: multiply_swapCustom)) +} + SimpleMathTests.test("CaptureLocal") { let z: Float = 10 func foo(_ x: Float) -> Float { From 47fff2439250cac56eebd93a5843cf215204df40 Mon Sep 17 00:00:00 2001 From: Brad Larson Date: Mon, 26 Jul 2021 19:01:53 -0500 Subject: [PATCH 03/12] Attempting to detect non-wrt inout parameters, all inouts are now results. --- lib/AST/Type.cpp | 16 ++++++++++++++++ lib/SIL/IR/SILFunctionType.cpp | 9 +-------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/lib/AST/Type.cpp b/lib/AST/Type.cpp index 70820b99c3352..273b1a2ba8f37 100644 --- a/lib/AST/Type.cpp +++ b/lib/AST/Type.cpp @@ -6457,6 +6457,22 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( auto resultTanType = resultTan->getType(); resultTanTypes.push_back(resultTanType); } + // Append non-wrt inout result tangent spaces. + auto *resultFunctionType = this->getResult()->getAs(); + auto sourceFunction = resultFunctionType ? resultFunctionType : this; + for (unsigned i : range(sourceFunction->getNumParams())) { + auto param = sourceFunction->getParams()[i]; + if (parameterIndices->contains(i)) + continue; + if (param.isInOut()) { + auto resultType = param.getPlainType(); + auto resultTan = resultType->getAutoDiffTangentSpace(lookupConformance); + if (!resultTan) + continue; + auto resultTanType = resultTan->getType(); + resultTanTypes.push_back(resultTanType); + } + } // Error if no semantic result has a tangent space. if (resultTanTypes.empty() && !hasInoutResult) { diff --git a/lib/SIL/IR/SILFunctionType.cpp b/lib/SIL/IR/SILFunctionType.cpp index 38412802f9c9d..4a8ac49ecfda1 100644 --- a/lib/SIL/IR/SILFunctionType.cpp +++ b/lib/SIL/IR/SILFunctionType.cpp @@ -238,8 +238,6 @@ IndexSubset *SILFunctionType::getDifferentiabilityResultIndices() { resultIndices.push_back(resultAndIndex.index()); // Check `inout` parameters. for (auto inoutParamAndIndex : enumerate(getIndirectMutatingParameters())) - // FIXME(TF-1305): The `getResults().empty()` condition is a hack. - // // Currently, an `inout` parameter can either be: // 1. Both a differentiability parameter and a differentiability result. // 2. `@noDerivative`: neither a differentiability parameter nor a @@ -251,13 +249,8 @@ IndexSubset *SILFunctionType::getDifferentiabilityResultIndices() { // cases, so supporting it is a non-goal. // // See TF-1305 for solution ideas. For now, `@noDerivative` `inout` - // parameters are not treated as differentiability results, unless the - // original function has no formal results, in which case all `inout` // parameters are treated as differentiability results. - if (getResults().empty() || - inoutParamAndIndex.value().getDifferentiability() != - SILParameterDifferentiability::NotDifferentiable) - resultIndices.push_back(getNumResults() + inoutParamAndIndex.index()); + resultIndices.push_back(getNumResults() + inoutParamAndIndex.index()); auto numSemanticResults = getNumResults() + getNumIndirectMutatingParameters(); return IndexSubset::get(getASTContext(), numSemanticResults, resultIndices); From 2d49ad5b76aa9fe76c67d71d33946375741d3f23 Mon Sep 17 00:00:00 2001 From: Brad Larson Date: Tue, 27 Jul 2021 17:54:00 -0500 Subject: [PATCH 04/12] Iterate over all result indices in getAutoDiffPullbackType() and getAutoDiffDifferentialType(), rather than the hardcoded 0 index. --- lib/SIL/IR/SILFunctionType.cpp | 12 ++++++++---- .../validation-test/forward_mode_simple.swift | 3 +++ test/AutoDiff/validation-test/simple_math.swift | 3 +++ 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/lib/SIL/IR/SILFunctionType.cpp b/lib/SIL/IR/SILFunctionType.cpp index 4a8ac49ecfda1..99509dfd3cb20 100644 --- a/lib/SIL/IR/SILFunctionType.cpp +++ b/lib/SIL/IR/SILFunctionType.cpp @@ -551,7 +551,9 @@ static CanSILFunctionType getAutoDiffDifferentialType( differentialParams.push_back({paramTanType, paramConv}); } SmallVector differentialResults; - for (auto resultIndex : resultIndices->getIndices()) { + // TODO(TF-1038): Instead iterate over resultIndices when these are specified + // via @differentiable(results:) or similar. + for (auto resultIndex : range(originalResults.size())) { // Handle formal original result. if (resultIndex < originalFnTy->getNumResults()) { auto &result = originalResults[resultIndex]; @@ -567,7 +569,7 @@ static CanSILFunctionType getAutoDiffDifferentialType( differentialResults.push_back({resultTanType, resultConv}); continue; } - // Handle original `inout` parameter. + // Handle original `inout` parameters. auto inoutParamIndex = resultIndex - originalFnTy->getNumResults(); auto inoutParamIt = std::next( originalFnTy->getIndirectMutatingParameters().begin(), inoutParamIndex); @@ -686,7 +688,9 @@ static CanSILFunctionType getAutoDiffPullbackType( // Collect pullback parameters. SmallVector pullbackParams; - for (auto resultIndex : resultIndices->getIndices()) { + // TODO(TF-1038): Instead iterate over resultIndices when these are specified + // via @differentiable(results:) or similar. + for (auto resultIndex : range(originalResults.size())) { // Handle formal original result. if (resultIndex < originalFnTy->getNumResults()) { auto &origRes = originalResults[resultIndex]; @@ -702,7 +706,7 @@ static CanSILFunctionType getAutoDiffPullbackType( pullbackParams.push_back({resultTanType, paramConv}); continue; } - // Handle original `inout` parameter. + // Handle `inout` parameters. auto inoutParamIndex = resultIndex - originalFnTy->getNumResults(); auto inoutParamIt = std::next( originalFnTy->getIndirectMutatingParameters().begin(), inoutParamIndex); diff --git a/test/AutoDiff/validation-test/forward_mode_simple.swift b/test/AutoDiff/validation-test/forward_mode_simple.swift index 990f8323bb8a2..c0bd82162414b 100644 --- a/test/AutoDiff/validation-test/forward_mode_simple.swift +++ b/test/AutoDiff/validation-test/forward_mode_simple.swift @@ -1063,6 +1063,8 @@ ForwardModeTests.test("FunctionCall") { expectEqual(3, derivative(at: 3) { x in foo(x, 4) }) } +// FIXME(TF-1038): Support differentiable functions returning tuples. +/* ForwardModeTests.test("ResultSelection") { func tuple(_ x: Float, _ y: Float) -> (Float, Float) { return (x + 1, y + 2) @@ -1081,6 +1083,7 @@ ForwardModeTests.test("ResultSelection") { expectEqual(1, derivative(at: 3, 3, of: tupleGenericSecond)) */ } +*/ // TODO(TF-983): Support forward-mode differentiation of multiple results. /* diff --git a/test/AutoDiff/validation-test/simple_math.swift b/test/AutoDiff/validation-test/simple_math.swift index c01202d83c7cd..b49a8e8504dfb 100644 --- a/test/AutoDiff/validation-test/simple_math.swift +++ b/test/AutoDiff/validation-test/simple_math.swift @@ -51,6 +51,8 @@ SimpleMathTests.test("FunctionCall") { expectEqual(3, gradient(at: 3) { x in foo(x, 4) }) } +// FIXME(TF-1038): Support differentiable functions returning tuples. +/* SimpleMathTests.test("ResultSelection") { func tuple(_ x: Float, _ y: Float) -> (Float, Float) { return (x + 1, y + 2) @@ -66,6 +68,7 @@ SimpleMathTests.test("ResultSelection") { expectEqual((1, 0), gradient(at: 3, 3, of: tupleGenericFirst)) expectEqual((0, 1), gradient(at: 3, 3, of: tupleGenericSecond)) } +*/ SimpleMathTests.test("MultipleResults") { // Test function returning a tuple of active results. From 98a833749693f5306b262a84e2ba492a46685ee6 Mon Sep 17 00:00:00 2001 From: Brad Larson Date: Fri, 30 Jul 2021 15:15:32 -0500 Subject: [PATCH 05/12] Returning all resultIndices from emitDifferentiabilityWitnessesForFunction(), adding some serialization tests. --- lib/SIL/IR/SILFunctionType.cpp | 8 ++----- lib/SILGen/SILGen.cpp | 12 ++++++++-- .../Serialization/derivative_attr.swift | 20 ++++++++++++++++ .../Serialization/differentiable_attr.swift | 23 +++++++++++++++++++ .../differentiable_function.swift | 12 ++++++++++ .../validation-test/simple_math.swift | 5 +--- 6 files changed, 68 insertions(+), 12 deletions(-) diff --git a/lib/SIL/IR/SILFunctionType.cpp b/lib/SIL/IR/SILFunctionType.cpp index 99509dfd3cb20..c39c7ab59eaeb 100644 --- a/lib/SIL/IR/SILFunctionType.cpp +++ b/lib/SIL/IR/SILFunctionType.cpp @@ -551,9 +551,7 @@ static CanSILFunctionType getAutoDiffDifferentialType( differentialParams.push_back({paramTanType, paramConv}); } SmallVector differentialResults; - // TODO(TF-1038): Instead iterate over resultIndices when these are specified - // via @differentiable(results:) or similar. - for (auto resultIndex : range(originalResults.size())) { + for (auto resultIndex : resultIndices->getIndices()) { // Handle formal original result. if (resultIndex < originalFnTy->getNumResults()) { auto &result = originalResults[resultIndex]; @@ -688,9 +686,7 @@ static CanSILFunctionType getAutoDiffPullbackType( // Collect pullback parameters. SmallVector pullbackParams; - // TODO(TF-1038): Instead iterate over resultIndices when these are specified - // via @differentiable(results:) or similar. - for (auto resultIndex : range(originalResults.size())) { + for (auto resultIndex : resultIndices->getIndices()) { // Handle formal original result. if (resultIndex < originalFnTy->getNumResults()) { auto &origRes = originalResults[resultIndex]; diff --git a/lib/SILGen/SILGen.cpp b/lib/SILGen/SILGen.cpp index b5c3a8ec9c9ce..3d4169c7a497d 100644 --- a/lib/SILGen/SILGen.cpp +++ b/lib/SILGen/SILGen.cpp @@ -1252,11 +1252,15 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction( auto *AFD = constant.getAbstractFunctionDecl(); auto emitWitnesses = [&](DeclAttributes &Attrs) { for (auto *diffAttr : Attrs.getAttributes()) { - auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0}); assert((!F->getLoweredFunctionType()->getSubstGenericSignature() || diffAttr->getDerivativeGenericSignature()) && "Type-checking should resolve derivative generic signatures for " "all original SIL functions with generic signatures"); + auto numResults = + F->getLoweredFunctionType()->getNumResults() + + F->getLoweredFunctionType()->getNumIndirectMutatingParameters(); + auto *resultIndices = IndexSubset::getDefault( + getASTContext(), numResults, /*includeAll*/ true); auto witnessGenSig = autodiff::getDifferentiabilityWitnessGenericSignature( AFD->getGenericSignature(), @@ -1285,7 +1289,11 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction( auto witnessGenSig = autodiff::getDifferentiabilityWitnessGenericSignature( origAFD->getGenericSignature(), AFD->getGenericSignature()); - auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0}); + auto numResults = + origFn->getLoweredFunctionType()->getNumResults() + + origFn->getLoweredFunctionType()->getNumIndirectMutatingParameters(); + auto *resultIndices = IndexSubset::getDefault( + getASTContext(), numResults, /*includeAll*/ true); AutoDiffConfig config(derivAttr->getParameterIndices(), resultIndices, witnessGenSig); emitDifferentiabilityWitness(origAFD, origFn, diff --git a/test/AutoDiff/Serialization/derivative_attr.swift b/test/AutoDiff/Serialization/derivative_attr.swift index 91677baa80e25..c41c0a36d1a50 100644 --- a/test/AutoDiff/Serialization/derivative_attr.swift +++ b/test/AutoDiff/Serialization/derivative_attr.swift @@ -37,6 +37,26 @@ func derivativeTop2( (y, { (dx, dy) in dy }) } +// Test top-level inout functions. + +func topInout1(_ x: inout S) {} + +// CHECK: @derivative(of: topInout1, wrt: x) +@derivative(of: topInout1) +func derivativeTopInout1(_ x: inout S) -> (value: Void, pullback: (inout S) -> Void) { + fatalError() +} + +func topInout2(_ x: inout S) -> S { + x +} + +// CHECK: @derivative(of: topInout2, wrt: x) +@derivative(of: topInout2) +func derivativeTopInout2(_ x: inout S) -> (value: S, pullback: (S, inout S) -> Void) { + fatalError() +} + // Test instance methods. extension S { diff --git a/test/AutoDiff/Serialization/differentiable_attr.swift b/test/AutoDiff/Serialization/differentiable_attr.swift index b8c83362bd813..e09f7541caf90 100644 --- a/test/AutoDiff/Serialization/differentiable_attr.swift +++ b/test/AutoDiff/Serialization/differentiable_attr.swift @@ -43,6 +43,29 @@ func testWrtClause(x: Float, y: Float) -> Float { return x } +// CHECK: @differentiable(reverse, wrt: x) +// CHECK-NEXT: func testInout(x: inout Float) +@differentiable(reverse) +func testInout(x: inout Float) { + x = x * 2.0 +} + +// CHECK: @differentiable(reverse, wrt: x) +// CHECK-NEXT: func testInoutResult(x: inout Float) -> Float +@differentiable(reverse) +func testInoutResult(x: inout Float) -> Float { + x = x * 2.0 + return x +} + +// CHECK: @differentiable(reverse, wrt: (x, y)) +// CHECK-NEXT: func testMultipleInout(x: inout Float, y: inout Float) +@differentiable(reverse) +func testMultipleInout(x: inout Float, y: inout Float) { + x = x * y + y = x +} + struct InstanceMethod : Differentiable { // CHECK: @differentiable(reverse, wrt: (self, y)) // CHECK-NEXT: func testWrtClause(x: Float, y: Float) -> Float diff --git a/test/AutoDiff/Serialization/differentiable_function.swift b/test/AutoDiff/Serialization/differentiable_function.swift index 316a0a6eca40d..e31d874bb8920 100644 --- a/test/AutoDiff/Serialization/differentiable_function.swift +++ b/test/AutoDiff/Serialization/differentiable_function.swift @@ -15,3 +15,15 @@ func b(_ f: @differentiable(_linear) (Float) -> Float) {} func c(_ f: @differentiable(reverse) (Float, @noDerivative Float) -> Float) {} // CHECK: func c(_ f: @differentiable(reverse) (Float, @noDerivative Float) -> Float) + +func d(_ f: @differentiable(reverse) (inout Float) -> ()) {} +// CHECK: func d(_ f: @differentiable(reverse) (inout Float) -> ()) + +func e(_ f: @differentiable(reverse) (inout Float, inout Float) -> ()) {} +// CHECK: func e(_ f: @differentiable(reverse) (inout Float, inout Float) -> ()) + +func f(_ f: @differentiable(reverse) (inout Float) -> Float) {} +// CHECK: func f(_ f: @differentiable(reverse) (inout Float) -> Float) + +func g(_ f: @differentiable(reverse) (inout Float, Float) -> Float) {} +// CHECK: func g(_ f: @differentiable(reverse) (inout Float, Float) -> Float) diff --git a/test/AutoDiff/validation-test/simple_math.swift b/test/AutoDiff/validation-test/simple_math.swift index b49a8e8504dfb..f45ff7ac396ce 100644 --- a/test/AutoDiff/validation-test/simple_math.swift +++ b/test/AutoDiff/validation-test/simple_math.swift @@ -51,8 +51,6 @@ SimpleMathTests.test("FunctionCall") { expectEqual(3, gradient(at: 3) { x in foo(x, 4) }) } -// FIXME(TF-1038): Support differentiable functions returning tuples. -/* SimpleMathTests.test("ResultSelection") { func tuple(_ x: Float, _ y: Float) -> (Float, Float) { return (x + 1, y + 2) @@ -68,7 +66,6 @@ SimpleMathTests.test("ResultSelection") { expectEqual((1, 0), gradient(at: 3, 3, of: tupleGenericFirst)) expectEqual((0, 1), gradient(at: 3, 3, of: tupleGenericSecond)) } -*/ SimpleMathTests.test("MultipleResults") { // Test function returning a tuple of active results. @@ -135,7 +132,7 @@ func vjpSwapCustom(_ x: inout Float, _ y: inout Float) -> ( ) { swapCustom(&x, &y) return ((), {v1, v2 in - let tmp = v1; v1 = v2; v2 = v1 + let tmp = v1; v1 = v2; v2 = tmp }) } From 6ce7df55e0d2124810b2b6e56b925148dad69a8d Mon Sep 17 00:00:00 2001 From: Brad Larson Date: Mon, 2 Aug 2021 10:54:44 -0500 Subject: [PATCH 06/12] Adding test for cross-module registration of functions with multiple semantic results. --- .../Inputs/a.swift | 22 +++++++++++++++++++ .../Inputs/b.swift | 15 +++++++++++++ 2 files changed, 37 insertions(+) diff --git a/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/a.swift b/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/a.swift index 59ec26ef9bd08..8942432a0a2a9 100644 --- a/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/a.swift +++ b/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/a.swift @@ -1,3 +1,25 @@ +import _Differentiation + public struct Struct { public func method(_ x: Float) -> Float { x } } + +// Test cross-module recognition of functions with multiple semantic results. +@differentiable(reverse) +public func swap(_ x: inout Float, _ y: inout Float) { + let tmp = x; x = y; y = tmp +} + +@differentiable(reverse) +public func swapCustom(_ x: inout Float, _ y: inout Float) { + let tmp = x; x = y; y = tmp +} +@derivative(of: swapCustom) +public func vjpSwapCustom(_ x: inout Float, _ y: inout Float) -> ( + value: Void, pullback: (inout Float, inout Float) -> Void +) { + swapCustom(&x, &y) + return ((), {v1, v2 in + let tmp = v1; v1 = v2; v2 = tmp + }) +} diff --git a/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/b.swift b/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/b.swift index 7947518708ad3..f6d3947757602 100644 --- a/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/b.swift +++ b/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/b.swift @@ -11,3 +11,18 @@ extension Struct: Differentiable { (x, { $0 }) } } + +// Test cross-module recognition of functions with multiple semantic results. +@differentiable(reverse) +func multiply_swap(_ x: Float, _ y: Float) -> Float { + var tuple = (x, y) + swap(&tuple.0, &tuple.1) + return tuple.0 * tuple.1 +} + +@differentiable(reverse) +func multiply_swapCustom(_ x: Float, _ y: Float) -> Float { + var tuple = (x, y) + swapCustom(&tuple.0, &tuple.1) + return tuple.0 * tuple.1 +} \ No newline at end of file From 2420dbc8b6a14a099bcf56b05eb7d2b5d8e86915 Mon Sep 17 00:00:00 2001 From: Brad Larson Date: Wed, 4 Aug 2021 09:45:39 -0600 Subject: [PATCH 07/12] Loosening checks that assume only one result. Adjusting tests. --- lib/SILOptimizer/Differentiation/Common.cpp | 4 ---- lib/SILOptimizer/Differentiation/Thunk.cpp | 4 ++-- .../Sema/DerivativeRegistrationCrossModule/Inputs/b.swift | 2 +- test/AutoDiff/Sema/derivative_attr_type_checking.swift | 4 ++-- test/AutoDiff/validation-test/forward_mode_simple.swift | 3 --- 5 files changed, 5 insertions(+), 12 deletions(-) diff --git a/lib/SILOptimizer/Differentiation/Common.cpp b/lib/SILOptimizer/Differentiation/Common.cpp index 33e0a3fa6f067..15056d1fd2e39 100644 --- a/lib/SILOptimizer/Differentiation/Common.cpp +++ b/lib/SILOptimizer/Differentiation/Common.cpp @@ -492,10 +492,6 @@ findMinimalDerivativeConfiguration(AbstractFunctionDecl *original, SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness( SILModule &module, SILFunction *original, DifferentiabilityKind kind, IndexSubset *parameterIndices, IndexSubset *resultIndices) { - // AST differentiability witnesses always have a single result. - if (resultIndices->getCapacity() != 1 || !resultIndices->contains(0)) - return nullptr; - // Explicit differentiability witnesses only exist on SIL functions that come // from AST functions. auto *originalAFD = findAbstractFunctionDecl(original); diff --git a/lib/SILOptimizer/Differentiation/Thunk.cpp b/lib/SILOptimizer/Differentiation/Thunk.cpp index f1a94e4320acd..5f4f1ca1b8a77 100644 --- a/lib/SILOptimizer/Differentiation/Thunk.cpp +++ b/lib/SILOptimizer/Differentiation/Thunk.cpp @@ -769,8 +769,8 @@ getOrCreateSubsetParametersThunkForDerivativeFunction( /*withoutActuallyEscaping*/ false); } assert(origFnType->getNumResults() + - origFnType->getNumIndirectMutatingParameters() == - 1); + origFnType->getNumIndirectMutatingParameters() > + 0); if (origFnType->getNumResults() > 0 && origFnType->getResults().front().isFormalDirect()) { auto result = diff --git a/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/b.swift b/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/b.swift index f6d3947757602..5485a5f9a68b7 100644 --- a/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/b.swift +++ b/test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/b.swift @@ -25,4 +25,4 @@ func multiply_swapCustom(_ x: Float, _ y: Float) -> Float { var tuple = (x, y) swapCustom(&tuple.0, &tuple.1) return tuple.0 * tuple.1 -} \ No newline at end of file +} diff --git a/test/AutoDiff/Sema/derivative_attr_type_checking.swift b/test/AutoDiff/Sema/derivative_attr_type_checking.swift index 2fb038dd3fe58..8585659d86f07 100644 --- a/test/AutoDiff/Sema/derivative_attr_type_checking.swift +++ b/test/AutoDiff/Sema/derivative_attr_type_checking.swift @@ -890,13 +890,13 @@ func vjpNoSemanticResults(_ x: Float) -> (value: Void, pullback: Void) {} extension InoutParameters { func multipleSemanticResults(_ x: inout Float) -> Float { x } - @derivative(of: multipleSemanticResults) + @derivative(of: multipleSemanticResults, wrt: x) func vjpMultipleSemanticResults(_ x: inout Float) -> ( value: Float, pullback: (Float, inout Float) -> Void ) { fatalError() } func inoutVoid(_ x: Float, _ void: inout Void) -> Float {} - @derivative(of: inoutVoid) + @derivative(of: inoutVoid, wrt: (x, void)) func vjpInoutVoidParameter(_ x: Float, _ void: inout Void) -> ( value: Float, pullback: (Float) -> Float ) { fatalError() } diff --git a/test/AutoDiff/validation-test/forward_mode_simple.swift b/test/AutoDiff/validation-test/forward_mode_simple.swift index c0bd82162414b..990f8323bb8a2 100644 --- a/test/AutoDiff/validation-test/forward_mode_simple.swift +++ b/test/AutoDiff/validation-test/forward_mode_simple.swift @@ -1063,8 +1063,6 @@ ForwardModeTests.test("FunctionCall") { expectEqual(3, derivative(at: 3) { x in foo(x, 4) }) } -// FIXME(TF-1038): Support differentiable functions returning tuples. -/* ForwardModeTests.test("ResultSelection") { func tuple(_ x: Float, _ y: Float) -> (Float, Float) { return (x + 1, y + 2) @@ -1083,7 +1081,6 @@ ForwardModeTests.test("ResultSelection") { expectEqual(1, derivative(at: 3, 3, of: tupleGenericSecond)) */ } -*/ // TODO(TF-983): Support forward-mode differentiation of multiple results. /* From f5d4885247abe0bc4805feffe1f70936b9ed7e4f Mon Sep 17 00:00:00 2001 From: Brad Larson Date: Wed, 4 Aug 2021 20:19:35 -0600 Subject: [PATCH 08/12] Reworking logic for non-wrt inout parameters. Replacing single result index with all result indices in multiple places. --- lib/AST/Type.cpp | 60 +++++++++++++++++++++++++------- lib/SIL/IR/SILDeclRef.cpp | 8 ++++- lib/Sema/TypeCheckAttr.cpp | 20 +++++++++-- lib/Sema/TypeCheckProtocol.cpp | 8 ++++- lib/Serialization/ModuleFile.cpp | 11 ++++-- 5 files changed, 86 insertions(+), 21 deletions(-) diff --git a/lib/AST/Type.cpp b/lib/AST/Type.cpp index 273b1a2ba8f37..81ec467dff928 100644 --- a/lib/AST/Type.cpp +++ b/lib/AST/Type.cpp @@ -6444,11 +6444,14 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( bool hasInoutResult = false; for (auto i : range(originalResults.size())) { auto originalResult = originalResults[i]; + auto originalResultType = originalResult.type; + // Voids currently have a defined tangent vector, so ignore them. + if (originalResultType->isVoid()) + continue; if (originalResult.isInout) { hasInoutResult = true; continue; } - auto originalResultType = originalResult.type; // Get the original semantic result type's `TangentVector` associated type. auto resultTan = originalResultType->getAutoDiffTangentSpace(lookupConformance); @@ -6458,19 +6461,48 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( resultTanTypes.push_back(resultTanType); } // Append non-wrt inout result tangent spaces. - auto *resultFunctionType = this->getResult()->getAs(); - auto sourceFunction = resultFunctionType ? resultFunctionType : this; - for (unsigned i : range(sourceFunction->getNumParams())) { - auto param = sourceFunction->getParams()[i]; - if (parameterIndices->contains(i)) - continue; - if (param.isInOut()) { - auto resultType = param.getPlainType(); - auto resultTan = resultType->getAutoDiffTangentSpace(lookupConformance); - if (!resultTan) + // This uses the logic from getSubsetParameters(), only operating over all + // parameter indices and looking for non-wrt indices. + SmallVector curryLevels; + // An inlined version of unwrapCurryLevels(). + AnyFunctionType *fnTy = this; + while (fnTy != nullptr) { + curryLevels.push_back(fnTy); + fnTy = fnTy->getResult()->getAs(); + } + + SmallVector curryLevelParameterIndexOffsets(curryLevels.size()); + unsigned currentOffset = 0; + for (unsigned curryLevelIndex : llvm::reverse(indices(curryLevels))) { + curryLevelParameterIndexOffsets[curryLevelIndex] = currentOffset; + currentOffset += curryLevels[curryLevelIndex]->getNumParams(); + } + + if (!makeSelfParamFirst) { + std::reverse(curryLevels.begin(), curryLevels.end()); + std::reverse(curryLevelParameterIndexOffsets.begin(), + curryLevelParameterIndexOffsets.end()); + } + + for (unsigned curryLevelIndex : indices(curryLevels)) { + auto *curryLevel = curryLevels[curryLevelIndex]; + unsigned parameterIndexOffset = + curryLevelParameterIndexOffsets[curryLevelIndex]; + for (unsigned paramIndex : range(curryLevel->getNumParams())) { + if (parameterIndices->contains(parameterIndexOffset + paramIndex)) continue; - auto resultTanType = resultTan->getType(); - resultTanTypes.push_back(resultTanType); + + auto param = curryLevel->getParams()[paramIndex]; + if (param.isInOut()) { + auto resultType = param.getPlainType(); + if (resultType->isVoid()) + continue; + auto resultTan = resultType->getAutoDiffTangentSpace(lookupConformance); + if (!resultTan) + continue; + auto resultTanType = resultTan->getType(); + resultTanTypes.push_back(resultTanType); + } } } @@ -6565,6 +6597,8 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( std::make_pair(paramType, i)); } if (diffParam.isInOut()) { + if (paramType->isVoid()) + continue; inoutParams.push_back(diffParam); continue; } diff --git a/lib/SIL/IR/SILDeclRef.cpp b/lib/SIL/IR/SILDeclRef.cpp index 1907913f9528f..e49c8464863a3 100644 --- a/lib/SIL/IR/SILDeclRef.cpp +++ b/lib/SIL/IR/SILDeclRef.cpp @@ -853,7 +853,13 @@ std::string SILDeclRef::mangle(ManglingKind MKind) const { auto *silParameterIndices = autodiff::getLoweredParameterIndices( derivativeFunctionIdentifier->getParameterIndices(), getDecl()->getInterfaceType()->castTo()); - auto *resultIndices = IndexSubset::get(getDecl()->getASTContext(), 1, {0}); + auto originalFn = + getDecl()->getInterfaceType()->castTo(); + SmallVector semanticResults; + autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults); + auto numResults = semanticResults.size(); + auto *resultIndices = IndexSubset::getDefault( + getDecl()->getASTContext(), numResults, /*includeAll*/ true); AutoDiffConfig silConfig( silParameterIndices, resultIndices, derivativeFunctionIdentifier->getDerivativeGenericSignature()); diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 827428412a6fe..f82fbcb1f2072 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -5073,7 +5073,12 @@ IndexSubset *DifferentiableAttributeTypeCheckRequest::evaluate( } getterDecl->getAttrs().add(newAttr); // Register derivative function configuration. - auto *resultIndices = IndexSubset::get(ctx, 1, {0}); + auto originalFn = getterDecl->getInterfaceType()->castTo(); + SmallVector semanticResults; + autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults); + auto numResults = semanticResults.size(); + auto *resultIndices = IndexSubset::getDefault( + ctx, numResults, /*includeAll*/ true); getterDecl->addDerivativeFunctionConfiguration( {resolvedDiffParamIndices, resultIndices, derivativeGenSig}); return resolvedDiffParamIndices; @@ -5088,7 +5093,11 @@ IndexSubset *DifferentiableAttributeTypeCheckRequest::evaluate( return nullptr; } // Register derivative function configuration. - auto *resultIndices = IndexSubset::get(ctx, 1, {0}); + SmallVector semanticResults; + autodiff::getFunctionSemanticResultTypes(originalFnRemappedTy, semanticResults); + auto numResults = semanticResults.size(); + auto *resultIndices = IndexSubset::getDefault( + ctx, numResults, /*includeAll*/ true); original->addDerivativeFunctionConfiguration( {resolvedDiffParamIndices, resultIndices, derivativeGenSig}); return resolvedDiffParamIndices; @@ -5510,7 +5519,12 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) { } // Register derivative function configuration. - auto *resultIndices = IndexSubset::get(Ctx, 1, {0}); + auto originalFn = originalAFD->getInterfaceType()->castTo(); + SmallVector semanticResults; + autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults); + auto numResults = semanticResults.size(); + auto *resultIndices = IndexSubset::getDefault( + Ctx, numResults, /*includeAll*/ true); originalAFD->addDerivativeFunctionConfiguration( {resolvedDiffParamIndices, resultIndices, derivative->getGenericSignature()}); diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index cc499e24aa585..123b8de37932a 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -490,7 +490,13 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req, witness->getAttrs().add(newAttr); success = true; // Register derivative function configuration. - auto *resultIndices = IndexSubset::get(ctx, 1, {0}); + auto originalFn = + witnessAFD->getInterfaceType()->castTo(); + SmallVector semanticResults; + autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults); + auto numResults = semanticResults.size(); + auto *resultIndices = IndexSubset::getDefault( + ctx, numResults, /*includeAll*/ true); witnessAFD->addDerivativeFunctionConfiguration( {newAttr->getParameterIndices(), resultIndices, newAttr->getDerivativeGenericSignature()}); diff --git a/lib/Serialization/ModuleFile.cpp b/lib/Serialization/ModuleFile.cpp index 642fe6a2a21d0..bb5ced193f4e3 100644 --- a/lib/Serialization/ModuleFile.cpp +++ b/lib/Serialization/ModuleFile.cpp @@ -667,9 +667,14 @@ void ModuleFile::loadDerivativeFunctionConfigurations( } auto derivativeGenSig = derivativeGenSigOrError.get(); // NOTE(TF-1038): Result indices are currently unsupported in derivative - // registration attributes. In the meantime, always use `{0}` (wrt the - // first and only result). - auto resultIndices = IndexSubset::get(ctx, 1, {0}); + // registration attributes. In the meantime, always use all results. + auto originalFn = + originalAFD->getInterfaceType()->castTo(); + SmallVector semanticResults; + autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults); + auto numResults = semanticResults.size(); + auto *resultIndices = IndexSubset::getDefault( + ctx, numResults, /*includeAll*/ true); results.insert({parameterIndices, resultIndices, derivativeGenSig}); } } From ff8dc58bfdeb39b9ad3fae3ed7fbfb6c0550a648 Mon Sep 17 00:00:00 2001 From: Brad Larson Date: Thu, 5 Aug 2021 16:43:56 -0600 Subject: [PATCH 09/12] Consolidated repeated result index generation into a central function. --- include/swift/AST/AutoDiff.h | 5 +++++ lib/AST/AutoDiff.cpp | 10 ++++++++++ lib/IRGen/IRGenMangler.h | 8 ++++++-- lib/SIL/IR/SILDeclRef.cpp | 9 ++------- lib/Sema/TypeCheckAttr.cpp | 16 ++++------------ lib/Sema/TypeCheckProtocol.cpp | 9 ++------- lib/Serialization/ModuleFile.cpp | 9 ++------- 7 files changed, 31 insertions(+), 35 deletions(-) diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index 4b692d2e7f51a..096831f944a6c 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -32,6 +32,7 @@ namespace swift { +class AbstractFunctionDecl; class AnyFunctionType; class SourceFile; class SILFunctionType; @@ -575,6 +576,10 @@ void getFunctionSemanticResultTypes( SmallVectorImpl &result, GenericEnvironment *genericEnv = nullptr); +/// Returns the indices of all semantic results for a given function. +IndexSubset *getAllFunctionSemanticResultIndices( + const AbstractFunctionDecl *AFD); + /// Returns the lowered SIL parameter indices for the given AST parameter /// indices and `AnyfunctionType`. /// diff --git a/lib/AST/AutoDiff.cpp b/lib/AST/AutoDiff.cpp index 67c7b88a49689..8a2ae3b326d42 100644 --- a/lib/AST/AutoDiff.cpp +++ b/lib/AST/AutoDiff.cpp @@ -211,6 +211,16 @@ void autodiff::getFunctionSemanticResultTypes( } } +IndexSubset * +autodiff::getAllFunctionSemanticResultIndices(const AbstractFunctionDecl *AFD) { + auto originalFn = AFD->getInterfaceType()->castTo(); + SmallVector semanticResults; + autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults); + auto numResults = semanticResults.size(); + return IndexSubset::getDefault( + AFD->getASTContext(), numResults, /*includeAll*/ true); +} + // TODO(TF-874): Simplify this helper. See TF-874 for WIP. IndexSubset * autodiff::getLoweredParameterIndices(IndexSubset *parameterIndices, diff --git a/lib/IRGen/IRGenMangler.h b/lib/IRGen/IRGenMangler.h index 28faba3baa527..f349ec805281a 100644 --- a/lib/IRGen/IRGenMangler.h +++ b/lib/IRGen/IRGenMangler.h @@ -57,9 +57,11 @@ class IRGenMangler : public Mangle::ASTMangler { AutoDiffDerivativeFunctionIdentifier *derivativeId) { beginManglingWithAutoDiffOriginalFunction(func); auto kind = Demangle::getAutoDiffFunctionKind(derivativeId->getKind()); + auto *resultIndices = + autodiff::getAllFunctionSemanticResultIndices(func); AutoDiffConfig config( derivativeId->getParameterIndices(), - IndexSubset::get(func->getASTContext(), 1, {0}), + resultIndices, derivativeId->getDerivativeGenericSignature()); appendAutoDiffFunctionParts("TJ", kind, config); appendOperator("Tj"); @@ -86,9 +88,11 @@ class IRGenMangler : public Mangle::ASTMangler { AutoDiffDerivativeFunctionIdentifier *derivativeId) { beginManglingWithAutoDiffOriginalFunction(func); auto kind = Demangle::getAutoDiffFunctionKind(derivativeId->getKind()); + auto *resultIndices = + autodiff::getAllFunctionSemanticResultIndices(func); AutoDiffConfig config( derivativeId->getParameterIndices(), - IndexSubset::get(func->getASTContext(), 1, {0}), + resultIndices, derivativeId->getDerivativeGenericSignature()); appendAutoDiffFunctionParts("TJ", kind, config); appendOperator("Tq"); diff --git a/lib/SIL/IR/SILDeclRef.cpp b/lib/SIL/IR/SILDeclRef.cpp index e49c8464863a3..f4640bdc14622 100644 --- a/lib/SIL/IR/SILDeclRef.cpp +++ b/lib/SIL/IR/SILDeclRef.cpp @@ -853,13 +853,8 @@ std::string SILDeclRef::mangle(ManglingKind MKind) const { auto *silParameterIndices = autodiff::getLoweredParameterIndices( derivativeFunctionIdentifier->getParameterIndices(), getDecl()->getInterfaceType()->castTo()); - auto originalFn = - getDecl()->getInterfaceType()->castTo(); - SmallVector semanticResults; - autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults); - auto numResults = semanticResults.size(); - auto *resultIndices = IndexSubset::getDefault( - getDecl()->getASTContext(), numResults, /*includeAll*/ true); + auto *resultIndices = autodiff::getAllFunctionSemanticResultIndices( + asAutoDiffOriginalFunction().getAbstractFunctionDecl()); AutoDiffConfig silConfig( silParameterIndices, resultIndices, derivativeFunctionIdentifier->getDerivativeGenericSignature()); diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index f82fbcb1f2072..95f5155d55c6e 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -5073,12 +5073,8 @@ IndexSubset *DifferentiableAttributeTypeCheckRequest::evaluate( } getterDecl->getAttrs().add(newAttr); // Register derivative function configuration. - auto originalFn = getterDecl->getInterfaceType()->castTo(); - SmallVector semanticResults; - autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults); - auto numResults = semanticResults.size(); - auto *resultIndices = IndexSubset::getDefault( - ctx, numResults, /*includeAll*/ true); + auto *resultIndices = + autodiff::getAllFunctionSemanticResultIndices(getterDecl); getterDecl->addDerivativeFunctionConfiguration( {resolvedDiffParamIndices, resultIndices, derivativeGenSig}); return resolvedDiffParamIndices; @@ -5519,12 +5515,8 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) { } // Register derivative function configuration. - auto originalFn = originalAFD->getInterfaceType()->castTo(); - SmallVector semanticResults; - autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults); - auto numResults = semanticResults.size(); - auto *resultIndices = IndexSubset::getDefault( - Ctx, numResults, /*includeAll*/ true); + auto *resultIndices = + autodiff::getAllFunctionSemanticResultIndices(originalAFD); originalAFD->addDerivativeFunctionConfiguration( {resolvedDiffParamIndices, resultIndices, derivative->getGenericSignature()}); diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index 123b8de37932a..fafa782c6bed1 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -490,13 +490,8 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req, witness->getAttrs().add(newAttr); success = true; // Register derivative function configuration. - auto originalFn = - witnessAFD->getInterfaceType()->castTo(); - SmallVector semanticResults; - autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults); - auto numResults = semanticResults.size(); - auto *resultIndices = IndexSubset::getDefault( - ctx, numResults, /*includeAll*/ true); + auto *resultIndices = + autodiff::getAllFunctionSemanticResultIndices(witnessAFD); witnessAFD->addDerivativeFunctionConfiguration( {newAttr->getParameterIndices(), resultIndices, newAttr->getDerivativeGenericSignature()}); diff --git a/lib/Serialization/ModuleFile.cpp b/lib/Serialization/ModuleFile.cpp index bb5ced193f4e3..ae9a3f4042923 100644 --- a/lib/Serialization/ModuleFile.cpp +++ b/lib/Serialization/ModuleFile.cpp @@ -668,13 +668,8 @@ void ModuleFile::loadDerivativeFunctionConfigurations( auto derivativeGenSig = derivativeGenSigOrError.get(); // NOTE(TF-1038): Result indices are currently unsupported in derivative // registration attributes. In the meantime, always use all results. - auto originalFn = - originalAFD->getInterfaceType()->castTo(); - SmallVector semanticResults; - autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults); - auto numResults = semanticResults.size(); - auto *resultIndices = IndexSubset::getDefault( - ctx, numResults, /*includeAll*/ true); + auto *resultIndices = + autodiff::getAllFunctionSemanticResultIndices(originalAFD); results.insert({parameterIndices, resultIndices, derivativeGenSig}); } } From 499703571bc8a4207ffb8208e594585524b31a24 Mon Sep 17 00:00:00 2001 From: Brad Larson Date: Fri, 6 Aug 2021 07:07:17 -0600 Subject: [PATCH 10/12] Converting a last few areas to use multiple result indices. --- lib/SILGen/SILGen.cpp | 14 ++++---------- lib/SILGen/SILGenThunk.cpp | 7 +++++-- lib/TBDGen/TBDGen.cpp | 15 ++++++++++----- 3 files changed, 19 insertions(+), 17 deletions(-) diff --git a/lib/SILGen/SILGen.cpp b/lib/SILGen/SILGen.cpp index 3d4169c7a497d..1cf60c3ab8fe7 100644 --- a/lib/SILGen/SILGen.cpp +++ b/lib/SILGen/SILGen.cpp @@ -1256,11 +1256,8 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction( diffAttr->getDerivativeGenericSignature()) && "Type-checking should resolve derivative generic signatures for " "all original SIL functions with generic signatures"); - auto numResults = - F->getLoweredFunctionType()->getNumResults() + - F->getLoweredFunctionType()->getNumIndirectMutatingParameters(); - auto *resultIndices = IndexSubset::getDefault( - getASTContext(), numResults, /*includeAll*/ true); + auto *resultIndices = + autodiff::getAllFunctionSemanticResultIndices(AFD); auto witnessGenSig = autodiff::getDifferentiabilityWitnessGenericSignature( AFD->getGenericSignature(), @@ -1289,11 +1286,8 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction( auto witnessGenSig = autodiff::getDifferentiabilityWitnessGenericSignature( origAFD->getGenericSignature(), AFD->getGenericSignature()); - auto numResults = - origFn->getLoweredFunctionType()->getNumResults() + - origFn->getLoweredFunctionType()->getNumIndirectMutatingParameters(); - auto *resultIndices = IndexSubset::getDefault( - getASTContext(), numResults, /*includeAll*/ true); + auto *resultIndices = + autodiff::getAllFunctionSemanticResultIndices(origAFD); AutoDiffConfig config(derivAttr->getParameterIndices(), resultIndices, witnessGenSig); emitDifferentiabilityWitness(origAFD, origFn, diff --git a/lib/SILGen/SILGenThunk.cpp b/lib/SILGen/SILGenThunk.cpp index 5476dc755b695..59d0c865d2478 100644 --- a/lib/SILGen/SILGenThunk.cpp +++ b/lib/SILGen/SILGenThunk.cpp @@ -547,11 +547,13 @@ SILFunction *SILGenModule::getOrCreateDerivativeVTableThunk( SILGenFunctionBuilder builder(*this); auto originalFnDeclRef = derivativeFnDeclRef.asAutoDiffOriginalFunction(); Mangle::ASTMangler mangler; + auto *resultIndices = autodiff::getAllFunctionSemanticResultIndices( + originalFnDeclRef.getAbstractFunctionDecl()); auto name = mangler.mangleAutoDiffDerivativeFunction( originalFnDeclRef.getAbstractFunctionDecl(), derivativeId->getKind(), AutoDiffConfig(derivativeId->getParameterIndices(), - IndexSubset::get(getASTContext(), 1, {0}), + resultIndices, derivativeId->getDerivativeGenericSignature()), /*isVTableThunk*/ true); auto *thunk = builder.getOrCreateFunction( @@ -571,7 +573,8 @@ SILFunction *SILGenModule::getOrCreateDerivativeVTableThunk( auto *loweredParamIndices = autodiff::getLoweredParameterIndices( derivativeId->getParameterIndices(), derivativeFnDecl->getInterfaceType()->castTo()); - auto *loweredResultIndices = IndexSubset::get(getASTContext(), 1, {0}); + auto *loweredResultIndices = autodiff::getAllFunctionSemanticResultIndices( + originalFnDeclRef.getAbstractFunctionDecl()); auto diffFn = SGF.B.createDifferentiableFunction( loc, loweredParamIndices, loweredResultIndices, originalFn); auto derivativeFn = SGF.B.createDifferentiableFunctionExtract( diff --git a/lib/TBDGen/TBDGen.cpp b/lib/TBDGen/TBDGen.cpp index afd7fd44559cb..c0989007011cc 100644 --- a/lib/TBDGen/TBDGen.cpp +++ b/lib/TBDGen/TBDGen.cpp @@ -730,22 +730,27 @@ void TBDGenVisitor::visitAbstractFunctionDecl(AbstractFunctionDecl *AFD) { // Add derivative function symbols. for (const auto *differentiableAttr : - AFD->getAttrs().getAttributes()) + AFD->getAttrs().getAttributes()) { + auto *resultIndices = + autodiff::getAllFunctionSemanticResultIndices(AFD); addDerivativeConfiguration( differentiableAttr->getDifferentiabilityKind(), AFD, AutoDiffConfig(differentiableAttr->getParameterIndices(), - IndexSubset::get(AFD->getASTContext(), 1, {0}), + resultIndices, differentiableAttr->getDerivativeGenericSignature())); + } for (const auto *derivativeAttr : - AFD->getAttrs().getAttributes()) + AFD->getAttrs().getAttributes()) { + auto *resultIndices = autodiff::getAllFunctionSemanticResultIndices( + derivativeAttr->getOriginalFunction(AFD->getASTContext())); addDerivativeConfiguration( DifferentiabilityKind::Reverse, derivativeAttr->getOriginalFunction(AFD->getASTContext()), AutoDiffConfig(derivativeAttr->getParameterIndices(), - IndexSubset::get(AFD->getASTContext(), 1, {0}), + resultIndices, AFD->getGenericSignature())); - + } visitDefaultArguments(AFD, AFD->getParameters()); if (AFD->hasAsync()) { From ffe1e529f558c7eca276c3adce7ea02b7c0f2879 Mon Sep 17 00:00:00 2001 From: Brad Larson Date: Thu, 12 Aug 2021 18:01:16 -0600 Subject: [PATCH 11/12] Added tuple result tests, extracting tuple elements as semantic result types. --- lib/AST/AutoDiff.cpp | 12 +++++- .../Sema/derivative_attr_type_checking.swift | 16 ++++++++ .../differentiable_attr_type_checking.swift | 13 +++++++ .../validation-test/simple_math.swift | 37 +++++++++++++++++++ 4 files changed, 76 insertions(+), 2 deletions(-) diff --git a/lib/AST/AutoDiff.cpp b/lib/AST/AutoDiff.cpp index 8a2ae3b326d42..a517854234eb8 100644 --- a/lib/AST/AutoDiff.cpp +++ b/lib/AST/AutoDiff.cpp @@ -196,8 +196,16 @@ void autodiff::getFunctionSemanticResultTypes( functionType->getResult()->getAs()) { formalResultType = resultFunctionType->getResult(); } - if (!formalResultType->isEqual(ctx.TheEmptyTupleType)) - result.push_back({remap(formalResultType), /*isInout*/ false}); + if (!formalResultType->isEqual(ctx.TheEmptyTupleType)) { + // Separate tuple elements into individual results. + if (formalResultType->is()) { + for (auto elt : formalResultType->castTo()->getElements()) { + result.push_back({remap(elt.getType()), /*isInout*/ false}); + } + } else { + result.push_back({remap(formalResultType), /*isInout*/ false}); + } + } // Collect `inout` parameters as semantic results. for (auto param : functionType->getParams()) diff --git a/test/AutoDiff/Sema/derivative_attr_type_checking.swift b/test/AutoDiff/Sema/derivative_attr_type_checking.swift index 8585659d86f07..690090699a9bf 100644 --- a/test/AutoDiff/Sema/derivative_attr_type_checking.swift +++ b/test/AutoDiff/Sema/derivative_attr_type_checking.swift @@ -902,6 +902,22 @@ extension InoutParameters { ) { fatalError() } } +// Test tuple results. + +extension InoutParameters { + func tupleResults(_ x: Float) -> (Float, Float) { (x, x) } + @derivative(of: tupleResults, wrt: x) + func vjpTupleResults(_ x: Float) -> ( + value: (Float, Float), pullback: (Float, Float) -> Float + ) { fatalError() } + + func tupleResultsInt(_ x: Float) -> (Int, Float) { (1, x) } + @derivative(of: tupleResultsInt, wrt: x) + func vjpTupleResults(_ x: Float) -> ( + value: (Int, Float), pullback: (Float) -> Float + ) { fatalError() } +} + // Test original/derivative function `inout` parameter mismatches. extension InoutParameters { diff --git a/test/AutoDiff/Sema/differentiable_attr_type_checking.swift b/test/AutoDiff/Sema/differentiable_attr_type_checking.swift index 16cfecfd23e1a..b9c2fee69d680 100644 --- a/test/AutoDiff/Sema/differentiable_attr_type_checking.swift +++ b/test/AutoDiff/Sema/differentiable_attr_type_checking.swift @@ -696,6 +696,19 @@ extension InoutParameters { mutating func mutatingMethod(_ other: Self) -> Self {} } +// Test tuple results. + +extension InoutParameters { + @differentiable(reverse) + static func tupleResults(_ x: Self) -> (Self, Self) {} + + @differentiable(reverse) + static func tupleResultsInt(_ x: Self) -> (Int, Self) {} + + @differentiable(reverse) + static func tupleResultsInt2(_ x: Self) -> (Self, Int) {} +} + // Test accessors: `set`, `_read`, `_modify`. struct Accessors: Differentiable { diff --git a/test/AutoDiff/validation-test/simple_math.swift b/test/AutoDiff/validation-test/simple_math.swift index f45ff7ac396ce..408a991cd2bff 100644 --- a/test/AutoDiff/validation-test/simple_math.swift +++ b/test/AutoDiff/validation-test/simple_math.swift @@ -147,6 +147,43 @@ SimpleMathTests.test("MultipleResultsWithCustomPullback") { expectEqual((10, 5), gradient(at: 5, 10, of: multiply_swapCustom)) } +// Test functions returning tuples. +@differentiable(reverse) +func swapTuple(_ x: Float, _ y: Float) -> (Float, Float) { + return (y, x) +} + +@differentiable(reverse) +func swapTupleCustom(_ x: Float, _ y: Float) -> (Float, Float) { + return (y, x) +} +@derivative(of: swapTupleCustom) +func vjpSwapTupleCustom(_ x: Float, _ y: Float) -> ( + value: (Float, Float), pullback: (Float, Float) -> (Float, Float) +) { + return (swapTupleCustom(x, y), {v1, v2 in + return (v2, v1) + }) +} + +SimpleMathTests.test("ReturningTuples") { + func multiply_swapTuple(_ x: Float, _ y: Float) -> Float { + let result = swapTuple(x, y) + return result.0 * result.1 + } + + expectEqual((4, 3), gradient(at: 3, 4, of: multiply_swapTuple)) + expectEqual((10, 5), gradient(at: 5, 10, of: multiply_swapTuple)) + + func multiply_swapTupleCustom(_ x: Float, _ y: Float) -> Float { + let result = swapTupleCustom(x, y) + return result.0 * result.1 + } + + expectEqual((4, 3), gradient(at: 3, 4, of: multiply_swapTupleCustom)) + expectEqual((10, 5), gradient(at: 5, 10, of: multiply_swapTupleCustom)) +} + SimpleMathTests.test("CaptureLocal") { let z: Float = 10 func foo(_ x: Float) -> Float { From 17e69879678e5628423fc31a5e90daa81660636f Mon Sep 17 00:00:00 2001 From: Brad Larson Date: Thu, 11 Aug 2022 18:29:18 -0500 Subject: [PATCH 12/12] Adding @asl's fix for subset parameters thunks involving functions with multiple results, and an activity analysis test representing code that had exposed that issue. Co-authored-by: Anton Korobeynikov --- lib/AST/Type.cpp | 3 +- lib/SILOptimizer/Differentiation/Common.cpp | 3 +- lib/SILOptimizer/Differentiation/Thunk.cpp | 77 ++++++++++++++++--- .../SILOptimizer/activity_analysis.swift | 38 +++++++++ 4 files changed, 107 insertions(+), 14 deletions(-) diff --git a/lib/AST/Type.cpp b/lib/AST/Type.cpp index 81ec467dff928..22dedc0cc4cb2 100644 --- a/lib/AST/Type.cpp +++ b/lib/AST/Type.cpp @@ -6555,9 +6555,8 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( SmallVector differentialResults; for (auto i : range(resultTanTypes.size())) { auto resultTanType = resultTanTypes[i]; - auto flags = ParameterTypeFlags().withInOut(false); differentialResults.push_back( - TupleTypeElt(resultTanType, Identifier(), flags)); + TupleTypeElt(resultTanType, Identifier())); } differentialResult = TupleType::get(differentialResults, ctx); } diff --git a/lib/SILOptimizer/Differentiation/Common.cpp b/lib/SILOptimizer/Differentiation/Common.cpp index 15056d1fd2e39..1e640fe2cd053 100644 --- a/lib/SILOptimizer/Differentiation/Common.cpp +++ b/lib/SILOptimizer/Differentiation/Common.cpp @@ -232,7 +232,8 @@ void collectMinimalIndicesForFunctionCall( auto ¶m = paramAndIdx.value(); if (!param.isIndirectMutating()) continue; - unsigned idx = paramAndIdx.index(); + unsigned idx = + paramAndIdx.index() + calleeFnTy->getNumIndirectFormalResults(); auto inoutArg = ai->getArgument(idx); results.push_back(inoutArg); resultIndices.push_back(inoutParamResultIndex++); diff --git a/lib/SILOptimizer/Differentiation/Thunk.cpp b/lib/SILOptimizer/Differentiation/Thunk.cpp index 5f4f1ca1b8a77..b44ba46db8d19 100644 --- a/lib/SILOptimizer/Differentiation/Thunk.cpp +++ b/lib/SILOptimizer/Differentiation/Thunk.cpp @@ -472,6 +472,12 @@ getOrCreateSubsetParametersThunkForLinearMap( return mappedIndex; }; + auto toIndirectResultsIter = thunk->getIndirectResults().begin(); + auto useNextIndirectResult = [&]() { + assert(toIndirectResultsIter != thunk->getIndirectResults().end()); + arguments.push_back(*toIndirectResultsIter++); + }; + switch (kind) { // Differential arguments are: // - All indirect results, followed by: @@ -480,9 +486,29 @@ getOrCreateSubsetParametersThunkForLinearMap( // indices). // - Zeros (when parameter is not in desired indices). case AutoDiffDerivativeFunctionKind::JVP: { - // Forward all indirect results. - arguments.append(thunk->getIndirectResults().begin(), - thunk->getIndirectResults().end()); + unsigned numIndirectResults = linearMapType->getNumIndirectFormalResults(); + // Forward desired indirect results. + for (unsigned idx : *actualConfig.resultIndices) { + if (idx >= numIndirectResults) + break; + + auto resultInfo = linearMapType->getResults()[idx]; + assert(idx < linearMapType->getNumResults()); + + // Forward result argument in case we do not need to thunk it away. + if (desiredConfig.resultIndices->contains(idx)) { + useNextIndirectResult(); + continue; + } + + // Otherwise, allocate and use an uninitialized indirect result. + auto *indirectResult = builder.createAllocStack( + loc, resultInfo.getSILStorageInterfaceType()); + localAllocations.push_back(indirectResult); + arguments.push_back(indirectResult); + } + assert(toIndirectResultsIter == thunk->getIndirectResults().end()); + auto toArgIter = thunk->getArgumentsWithoutIndirectResults().begin(); auto useNextArgument = [&]() { arguments.push_back(*toArgIter++); }; // Iterate over actual indices. @@ -507,10 +533,6 @@ getOrCreateSubsetParametersThunkForLinearMap( // - Zeros (when parameter is not in desired indices). // - All actual arguments. case AutoDiffDerivativeFunctionKind::VJP: { - auto toIndirectResultsIter = thunk->getIndirectResults().begin(); - auto useNextIndirectResult = [&]() { - arguments.push_back(*toIndirectResultsIter++); - }; // Collect pullback arguments. unsigned pullbackResultIndex = 0; for (unsigned i : actualConfig.parameterIndices->getIndices()) { @@ -539,8 +561,18 @@ getOrCreateSubsetParametersThunkForLinearMap( arguments.push_back(indirectResult); } // Forward all actual non-indirect-result arguments. - arguments.append(thunk->getArgumentsWithoutIndirectResults().begin(), - thunk->getArgumentsWithoutIndirectResults().end() - 1); + auto thunkArgs = thunk->getArgumentsWithoutIndirectResults(); + // Slice out the function to be called. + thunkArgs = thunkArgs.slice(0, thunkArgs.size() - 1); + unsigned thunkArg = 0; + for (unsigned idx : *actualConfig.resultIndices) { + // Forward result argument in case we do not need to thunk it away. + if (desiredConfig.resultIndices->contains(idx)) + arguments.push_back(thunkArgs[thunkArg++]); + else { // Otherwise, zero it out. + buildZeroArgument(linearMapType->getParameters()[arguments.size()]); + } + } break; } } @@ -550,10 +582,33 @@ getOrCreateSubsetParametersThunkForLinearMap( auto *ai = builder.createApply(loc, linearMap, SubstitutionMap(), arguments); // If differential thunk, deallocate local allocations and directly return - // `apply` result. + // `apply` result (if it is desired). if (kind == AutoDiffDerivativeFunctionKind::JVP) { + SmallVector differentialDirectResults; + extractAllElements(ai, builder, differentialDirectResults); + SmallVector allResults; + collectAllActualResultsInTypeOrder(ai, differentialDirectResults, + allResults); + unsigned numResults = thunk->getConventions().getNumDirectSILResults() + + thunk->getConventions().getNumDirectSILResults(); + SmallVector results; + for (unsigned idx : *actualConfig.resultIndices) { + if (idx >= numResults) + break; + + auto result = allResults[idx]; + if (desiredConfig.isWrtResult(idx)) + results.push_back(result); + else { + if (result->getType().isAddress()) + builder.emitDestroyAddrAndFold(loc, result); + else + builder.emitDestroyValueOperation(loc, result); + } + } cleanupValues(); - builder.createReturn(loc, ai); + auto result = joinElements(results, builder, loc); + builder.createReturn(loc, result); return {thunk, interfaceSubs}; } diff --git a/test/AutoDiff/SILOptimizer/activity_analysis.swift b/test/AutoDiff/SILOptimizer/activity_analysis.swift index 3281dbe7f2358..850a5c6e25489 100644 --- a/test/AutoDiff/SILOptimizer/activity_analysis.swift +++ b/test/AutoDiff/SILOptimizer/activity_analysis.swift @@ -533,6 +533,44 @@ func activeInoutArgNonactiveInitialResult(_ x: Float) -> Float { // CHECK: [ACTIVE] %13 = begin_access [read] [static] %2 : $*Float // CHECK: [ACTIVE] %14 = load [trivial] %13 : $*Float +public struct ArrayWrapper: Differentiable { + var values: [Float] + + @differentiable(reverse) + mutating func get(index: Int) -> Float { + self.values[index] + } + + // Check `inout` with result. + + // CHECK-LABEL: [AD] Activity info for ${{.*}}get{{.*}} at parameter indices (1) and result indices (0, 1) + // CHECK: bb0: + // CHECK: [USEFUL] %0 = argument of bb0 : $Int + // CHECK: [ACTIVE] %1 = argument of bb0 : $*ArrayWrapper + // CHECK: [ACTIVE] %4 = begin_access [read] [static] %1 : $*ArrayWrapper + // CHECK: [ACTIVE] %5 = struct_element_addr %4 : $*ArrayWrapper, #ArrayWrapper.values + // CHECK: [ACTIVE] %6 = load_borrow %5 : $*Array + // CHECK: [ACTIVE] %7 = alloc_stack $Float + // CHECK: [NONE] // function_ref Array.subscript.getter + // CHECK: %8 = function_ref @$sSayxSicig : $@convention(method) <τ_0_0> (Int, @guaranteed Array<τ_0_0>) -> @out τ_0_0 + // CHECK: [NONE] %9 = apply %8(%7, %0, %6) : $@convention(method) <τ_0_0> (Int, @guaranteed Array<τ_0_0>) -> @out τ_0_0 + // CHECK: [ACTIVE] %10 = load [trivial] %7 : $*Float +} + +@differentiable(reverse) +func testInoutAndResult(x: Int, y: inout ArrayWrapper) { + let _ = y.get(index: x) +} + +// CHECK-LABEL: [AD] Activity info for ${{.*}}testInoutAndResult{{.*}} at parameter indices (1) and result indices (0) +// CHECK: bb0: +// CHECK: [USEFUL] %0 = argument of bb0 : $Int +// CHECK: [ACTIVE] %1 = argument of bb0 : $*ArrayWrapper +// CHECK: [ACTIVE] %4 = begin_access [modify] [static] %1 : $*ArrayWrapper +// CHECK: [NONE] // function_ref ArrayWrapper.get(index:) +// CHECK: %5 = function_ref @$s17activity_analysis12ArrayWrapperV3get5indexSfSi_tF : $@convention(method) (Int, @inout ArrayWrapper) -> Float +// CHECK: [VARIED] %6 = apply %5(%0, %4) : $@convention(method) (Int, @inout ArrayWrapper) -> Float + //===----------------------------------------------------------------------===// // Throwing function differentiation (`try_apply`) //===----------------------------------------------------------------------===//