1717#ifndef SWIFT_TYPES_H
1818#define SWIFT_TYPES_H
1919
20+ #include " swift/AST/AutoDiff.h"
2021#include " swift/AST/DeclContext.h"
2122#include " swift/AST/GenericParamKey.h"
2223#include " swift/AST/Identifier.h"
@@ -300,8 +301,8 @@ class alignas(1 << TypeAlignInBits) TypeBase {
300301 }
301302
302303protected:
303- enum { NumAFTExtInfoBits = 6 };
304- enum { NumSILExtInfoBits = 6 };
304+ enum { NumAFTExtInfoBits = 8 };
305+ enum { NumSILExtInfoBits = 8 };
305306 union { uint64_t OpaqueBits;
306307
307308 SWIFT_INLINE_BITFIELD_BASE (TypeBase, bitmax (NumTypeKindBits,8 ) +
@@ -2875,14 +2876,16 @@ class AnyFunctionType : public TypeBase {
28752876 // If bits are added or removed, then TypeBase::AnyFunctionTypeBits
28762877 // and NumMaskBits must be updated, and they must match.
28772878 //
2878- // |representation|noEscape|throws|
2879- // | 0 .. 3 | 4 | 5 |
2879+ // |representation|noEscape|throws|differentiability|
2880+ // | 0 .. 3 | 4 | 5 | 6 .. 7 |
28802881 //
28812882 enum : unsigned {
2882- RepresentationMask = 0xF << 0 ,
2883- NoEscapeMask = 1 << 4 ,
2884- ThrowsMask = 1 << 5 ,
2885- NumMaskBits = 6
2883+ RepresentationMask = 0xF << 0 ,
2884+ NoEscapeMask = 1 << 4 ,
2885+ ThrowsMask = 1 << 5 ,
2886+ DifferentiabilityMaskOffset = 6 ,
2887+ DifferentiabilityMask = 0x3 << DifferentiabilityMaskOffset,
2888+ NumMaskBits = 8
28862889 };
28872890
28882891 unsigned Bits; // Naturally sized for speed.
@@ -2905,13 +2908,24 @@ class AnyFunctionType : public TypeBase {
29052908 // Constructor with no defaults.
29062909 ExtInfo (Representation Rep,
29072910 bool IsNoEscape,
2908- bool Throws)
2911+ bool Throws,
2912+ DifferentiabilityKind DiffKind)
29092913 : ExtInfo(Rep, Throws) {
29102914 Bits |= (IsNoEscape ? NoEscapeMask : 0 );
2915+ Bits |= ((unsigned )DiffKind << DifferentiabilityMaskOffset) &
2916+ DifferentiabilityMask;
29112917 }
29122918
29132919 bool isNoEscape () const { return Bits & NoEscapeMask; }
29142920 bool throws () const { return Bits & ThrowsMask; }
2921+ bool isDifferentiable () const {
2922+ return getDifferentiabilityKind () >
2923+ DifferentiabilityKind::NonDifferentiable;
2924+ }
2925+ DifferentiabilityKind getDifferentiabilityKind () const {
2926+ return DifferentiabilityKind ((Bits & DifferentiabilityMask) >>
2927+ DifferentiabilityMaskOffset);
2928+ }
29152929 Representation getRepresentation () const {
29162930 unsigned rawRep = Bits & RepresentationMask;
29172931 assert (rawRep <= unsigned (Representation::Last)
@@ -3069,6 +3083,11 @@ class AnyFunctionType : public TypeBase {
30693083 return getExtInfo ().throws ();
30703084 }
30713085
3086+ bool isDifferentiable () const { return getExtInfo ().isDifferentiable (); }
3087+ DifferentiabilityKind getDifferentiabilityKind () const {
3088+ return getExtInfo ().getDifferentiabilityKind ();
3089+ }
3090+
30723091 // / Returns a new function type exactly like this one but with the ExtInfo
30733092 // / replaced.
30743093 AnyFunctionType *withExtInfo (ExtInfo info) const ;
@@ -3716,14 +3735,16 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
37163735 // If bits are added or removed, then TypeBase::SILFunctionTypeBits
37173736 // and NumMaskBits must be updated, and they must match.
37183737
3719- // |representation|pseudogeneric| noescape |
3720- // | 0 .. 3 | 4 | 5 |
3738+ // |representation|pseudogeneric| noescape |differentiability|
3739+ // | 0 .. 3 | 4 | 5 | 6 .. 7 |
37213740 //
37223741 enum : unsigned {
37233742 RepresentationMask = 0xF << 0 ,
37243743 PseudogenericMask = 1 << 4 ,
37253744 NoEscapeMask = 1 << 5 ,
3726- NumMaskBits = 6
3745+ DifferentiabilityMaskOffset = 6 ,
3746+ DifferentiabilityMask = 0x3 << DifferentiabilityMaskOffset,
3747+ NumMaskBits = 8
37273748 };
37283749
37293750 unsigned Bits; // Naturally sized for speed.
@@ -3737,10 +3758,13 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
37373758 ExtInfo () : Bits(0 ) { }
37383759
37393760 // Constructor for polymorphic type.
3740- ExtInfo (Representation rep, bool isPseudogeneric, bool isNoEscape) {
3761+ ExtInfo (Representation rep, bool isPseudogeneric, bool isNoEscape,
3762+ DifferentiabilityKind diffKind) {
37413763 Bits = ((unsigned ) rep) |
37423764 (isPseudogeneric ? PseudogenericMask : 0 ) |
3743- (isNoEscape ? NoEscapeMask : 0 );
3765+ (isNoEscape ? NoEscapeMask : 0 ) |
3766+ (((unsigned )diffKind << DifferentiabilityMaskOffset) &
3767+ DifferentiabilityMask);
37443768 }
37453769
37463770 // / Is this function pseudo-generic? A pseudo-generic function
@@ -3750,6 +3774,16 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
37503774 // Is this function guaranteed to be no-escape by the type system?
37513775 bool isNoEscape () const { return Bits & NoEscapeMask; }
37523776
3777+ bool isDifferentiable () const {
3778+ return getDifferentiabilityKind () !=
3779+ DifferentiabilityKind::NonDifferentiable;
3780+ }
3781+
3782+ DifferentiabilityKind getDifferentiabilityKind () const {
3783+ return DifferentiabilityKind ((Bits & DifferentiabilityMask) >>
3784+ DifferentiabilityMaskOffset);
3785+ }
3786+
37533787 // / What is the abstract representation of this function value?
37543788 Representation getRepresentation () const {
37553789 return Representation (Bits & RepresentationMask);
@@ -4154,6 +4188,11 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
41544188 getRepresentation () == SILFunctionTypeRepresentation::Thick;
41554189 }
41564190
4191+ bool isDifferentiable () const { return getExtInfo ().isDifferentiable (); }
4192+ DifferentiabilityKind getDifferentiabilityKind () const {
4193+ return getExtInfo ().getDifferentiabilityKind ();
4194+ }
4195+
41574196 bool isNoReturnFunction (SILModule &M) const ; // Defined in SILType.cpp
41584197
41594198 // / Create a SILFunctionType with the same parameters, results, and attributes as this one, but with
0 commit comments