From 7d1a94bbb22f73936d7323447e9e9fdfc68fcf07 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sat, 11 Jul 2020 11:15:24 -0700 Subject: [PATCH] [AutoDiff] Serialize and print `@derivative` and `@transpose` accessor kind. Serialize and print the optional accessor kind in `@derivative` and `@transpose` attributes. Resolves TF-1293. --- include/swift/AST/Attr.h | 19 +++++---- lib/AST/Attr.cpp | 14 ++++++- lib/AST/AutoDiff.cpp | 8 ++++ lib/Serialization/Deserialization.cpp | 16 ++++++-- lib/Serialization/ModuleFormat.h | 4 +- lib/Serialization/Serialization.cpp | 10 ++++- .../Serialization/derivative_attr.swift | 41 ++++++++++++++++++- .../Serialization/transpose_attr.swift | 5 ++- 8 files changed, 99 insertions(+), 18 deletions(-) diff --git a/include/swift/AST/Attr.h b/include/swift/AST/Attr.h index 98eb453af7a36..25fa621adf0bc 100644 --- a/include/swift/AST/Attr.h +++ b/include/swift/AST/Attr.h @@ -1715,13 +1715,6 @@ class OriginallyDefinedInAttr: public DeclAttribute { } }; -/// A declaration name with location. -struct DeclNameRefWithLoc { - DeclNameRef Name; - DeclNameLoc Loc; - Optional AccessorKind; -}; - /// Attribute that marks a function as differentiable. /// /// Examples: @@ -1847,6 +1840,18 @@ class DifferentiableAttr final } }; +/// A declaration name with location. +struct DeclNameRefWithLoc { + /// The declaration name. + DeclNameRef Name; + /// The declaration name location. + DeclNameLoc Loc; + /// An optional accessor kind. + Optional AccessorKind; + + void print(ASTPrinter &Printer) const; +}; + /// The `@derivative(of:)` attribute registers a function as a derivative of /// another function-like declaration: a 'func', 'init', 'subscript', or 'var' /// computed property declaration. diff --git a/lib/AST/Attr.cpp b/lib/AST/Attr.cpp index 5369fe49b2251..ac9989bc9cea0 100644 --- a/lib/AST/Attr.cpp +++ b/lib/AST/Attr.cpp @@ -1052,7 +1052,9 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options, Printer.printAttrName("@derivative"); Printer << "(of: "; auto *attr = cast(this); - Printer << attr->getOriginalFunctionName().Name; + if (auto *baseType = attr->getBaseTypeRepr()) + baseType->print(Printer, Options); + attr->getOriginalFunctionName().print(Printer); auto *derivative = cast(D); auto diffParamsString = getDifferentiationParametersClauseString( derivative, attr->getParameterIndices(), attr->getParsedParameters(), @@ -1067,7 +1069,9 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options, Printer.printAttrName("@transpose"); Printer << "(of: "; auto *attr = cast(this); - Printer << attr->getOriginalFunctionName().Name; + if (auto *baseType = attr->getBaseTypeRepr()) + baseType->print(Printer, Options); + attr->getOriginalFunctionName().print(Printer); auto *transpose = cast(D); auto transParamsString = getDifferentiationParametersClauseString( transpose, attr->getParameterIndices(), attr->getParsedParameters(), @@ -1719,6 +1723,12 @@ GenericEnvironment *DifferentiableAttr::getDerivativeGenericEnvironment( return original->getGenericEnvironment(); } +void DeclNameRefWithLoc::print(ASTPrinter &Printer) const { + Printer << Name; + if (AccessorKind) + Printer << '.' << getAccessorLabel(*AccessorKind); +} + void DifferentiableAttr::print(llvm::raw_ostream &OS, const Decl *D, bool omitWrtClause) const { StreamPrinter P(OS); diff --git a/lib/AST/AutoDiff.cpp b/lib/AST/AutoDiff.cpp index 83b7d981f08db..da55b1f0cc01c 100644 --- a/lib/AST/AutoDiff.cpp +++ b/lib/AST/AutoDiff.cpp @@ -422,6 +422,14 @@ void DerivativeFunctionTypeError::log(raw_ostream &OS) const { } } +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const DeclNameRefWithLoc &name) { + os << name.Name; + if (auto accessorKind = name.AccessorKind) + os << '.' << getAccessorLabel(*accessorKind); + return os; +} + bool swift::operator==(const TangentPropertyInfo::Error &lhs, const TangentPropertyInfo::Error &rhs) { if (lhs.kind != rhs.kind) diff --git a/lib/Serialization/Deserialization.cpp b/lib/Serialization/Deserialization.cpp index 4237e03aa6e41..85ae497d11c6b 100644 --- a/lib/Serialization/Deserialization.cpp +++ b/lib/Serialization/Deserialization.cpp @@ -4371,16 +4371,26 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() { case decls_block::Derivative_DECL_ATTR: { bool isImplicit; uint64_t origNameId; + bool hasAccessorKind; + uint64_t rawAccessorKind; DeclID origDeclId; uint64_t rawDerivativeKind; ArrayRef parameters; serialization::decls_block::DerivativeDeclAttrLayout::readRecord( - scratch, isImplicit, origNameId, origDeclId, rawDerivativeKind, - parameters); + scratch, isImplicit, origNameId, hasAccessorKind, rawAccessorKind, + origDeclId, rawDerivativeKind, parameters); + + Optional accessorKind = None; + if (hasAccessorKind) { + auto maybeAccessorKind = getActualAccessorKind(rawAccessorKind); + if (!maybeAccessorKind) + MF.fatal(); + accessorKind = *maybeAccessorKind; + } DeclNameRefWithLoc origName{DeclNameRef(MF.getDeclBaseName(origNameId)), - DeclNameLoc(), None}; + DeclNameLoc(), accessorKind}; auto derivativeKind = getActualAutoDiffDerivativeFunctionKind(rawDerivativeKind); if (!derivativeKind) diff --git a/lib/Serialization/ModuleFormat.h b/lib/Serialization/ModuleFormat.h index 9bf7c6b8e4647..19666c1920e38 100644 --- a/lib/Serialization/ModuleFormat.h +++ b/lib/Serialization/ModuleFormat.h @@ -55,7 +55,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0; /// describe what change you made. The content of this comment isn't important; /// it just ensures a conflict if two people change the module format. /// Don't worry about adhering to the 80-column limit for this line. -const uint16_t SWIFTMODULE_VERSION_MINOR = 563; // unchecked_value_cast +const uint16_t SWIFTMODULE_VERSION_MINOR = 564; // `@derivative` attribute accessor kind /// A standard hash seed used for all string hashes in a serialized module. /// @@ -1848,6 +1848,8 @@ namespace decls_block { Derivative_DECL_ATTR, BCFixed<1>, // Implicit flag. IdentifierIDField, // Original name. + BCFixed<1>, // Has original accessor kind? + AccessorKindField, // Original accessor kind. DeclIDField, // Original function declaration. AutoDiffDerivativeFunctionKindField, // Derivative function kind. BCArray> // Differentiation parameter indices' bitvector. diff --git a/lib/Serialization/Serialization.cpp b/lib/Serialization/Serialization.cpp index 6b6d26d1e5646..a2acea8551083 100644 --- a/lib/Serialization/Serialization.cpp +++ b/lib/Serialization/Serialization.cpp @@ -2431,11 +2431,16 @@ class Serializer::DeclSerializer : public DeclVisitor { assert(attr->getOriginalFunction(ctx) && "`@derivative` attribute should have original declaration set " "during construction or parsing"); - auto origName = attr->getOriginalFunctionName().Name.getBaseName(); + auto origDeclNameRef = attr->getOriginalFunctionName(); + auto origName = origDeclNameRef.Name.getBaseName(); IdentifierID origNameId = S.addDeclBaseNameRef(origName); DeclID origDeclID = S.addDeclRef(attr->getOriginalFunction(ctx)); auto derivativeKind = getRawStableAutoDiffDerivativeFunctionKind(attr->getDerivativeKind()); + uint8_t rawAccessorKind = 0; + auto origAccessorKind = origDeclNameRef.AccessorKind; + if (origAccessorKind) + rawAccessorKind = uint8_t(getStableAccessorKind(*origAccessorKind)); auto *parameterIndices = attr->getParameterIndices(); assert(parameterIndices && "Parameter indices must be resolved"); SmallVector paramIndicesVector; @@ -2443,7 +2448,8 @@ class Serializer::DeclSerializer : public DeclVisitor { paramIndicesVector.push_back(parameterIndices->contains(i)); DerivativeDeclAttrLayout::emitRecord( S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(), origNameId, - origDeclID, derivativeKind, paramIndicesVector); + origAccessorKind.hasValue(), rawAccessorKind, origDeclID, + derivativeKind, paramIndicesVector); return; } diff --git a/test/AutoDiff/Serialization/derivative_attr.swift b/test/AutoDiff/Serialization/derivative_attr.swift index 0a2094f4bb360..91677baa80e25 100644 --- a/test/AutoDiff/Serialization/derivative_attr.swift +++ b/test/AutoDiff/Serialization/derivative_attr.swift @@ -56,8 +56,11 @@ extension S { (self, { $0 }) } + // Note: qualified name base types are not yet serialized and are not printed + // when round-tripping. + // CHECK: @derivative(of: instanceMethod, wrt: (self, x)) - @derivative(of: instanceMethod, wrt: (self, x)) + @derivative(of: S.instanceMethod, wrt: (self, x)) func derivativeInstanceMethodWrtAll(_ x: S) -> (value: S, differential: (S, S) -> S) { (self, { (dself, dx) in self }) } @@ -81,7 +84,8 @@ extension S { extension S { var computedProperty: S { - self + get { self } + set {} } // CHECK: @derivative(of: computedProperty, wrt: self) @@ -89,11 +93,30 @@ extension S { func derivativeProperty() -> (value: S, differential: (S) -> S) { (self, { $0 }) } + + // CHECK: @derivative(of: computedProperty.get, wrt: self) + @derivative(of: computedProperty.get, wrt: self) + func derivativePropertyGetter() -> (value: S, pullback: (S) -> S) { + fatalError() + } + + // CHECK: @derivative(of: computedProperty.set, wrt: (self, newValue)) + @derivative(of: computedProperty.set, wrt: (self, newValue)) + mutating func derivativePropertySetter(_ newValue: S) -> ( + value: (), pullback: (inout S) -> S + ) { + fatalError() + } } // Test subscripts. extension S { + subscript() -> S { + get { self } + set {} + } + subscript(x: T) -> S { self } @@ -103,4 +126,18 @@ extension S { func derivativeSubscript(x: T) -> (value: S, differential: (S) -> S) { (self, { $0 }) } + + // CHECK: @derivative(of: subscript.get, wrt: self) + @derivative(of: subscript.get, wrt: self) + func derivativeSubscriptGetter() -> (value: S, pullback: (S) -> S) { + fatalError() + } + + // CHECK: @derivative(of: subscript.set, wrt: (self, newValue)) + @derivative(of: subscript.set, wrt: (self, newValue)) + mutating func derivativeSubscriptSetter(_ newValue: S) -> ( + value: (), pullback: (inout S) -> S + ) { + fatalError() + } } diff --git a/test/AutoDiff/Serialization/transpose_attr.swift b/test/AutoDiff/Serialization/transpose_attr.swift index 2d5c3470af40a..9f864862cda95 100644 --- a/test/AutoDiff/Serialization/transpose_attr.swift +++ b/test/AutoDiff/Serialization/transpose_attr.swift @@ -50,8 +50,11 @@ extension S { self + t } + // Note: qualified name base types are not yet serialized and are not printed + // when round-tripping. + // CHECK: @transpose(of: instanceMethod, wrt: self) - @transpose(of: instanceMethod, wrt: self) + @transpose(of: S.instanceMethod, wrt: self) static func transposeInstanceMethodWrtSelf(_ other: S, t: S) -> S { other + t }