1717#ifndef SWIFT_TYPES_H
1818#define SWIFT_TYPES_H
1919
20+ #include " swift/AST/AutoDiff.h"
2021#include " swift/AST/DeclContext.h"
2122#include " swift/AST/GenericParamKey.h"
2223#include " swift/AST/Identifier.h"
@@ -301,8 +302,8 @@ class alignas(1 << TypeAlignInBits) TypeBase {
301302 }
302303
303304protected:
304- enum { NumAFTExtInfoBits = 6 };
305- enum { NumSILExtInfoBits = 6 };
305+ enum { NumAFTExtInfoBits = 8 };
306+ enum { NumSILExtInfoBits = 8 };
306307 union { uint64_t OpaqueBits;
307308
308309 SWIFT_INLINE_BITFIELD_BASE (TypeBase, bitmax (NumTypeKindBits,8 ) +
@@ -2879,14 +2880,16 @@ class AnyFunctionType : public TypeBase {
28792880 // If bits are added or removed, then TypeBase::AnyFunctionTypeBits
28802881 // and NumMaskBits must be updated, and they must match.
28812882 //
2882- // |representation|noEscape|throws|
2883- // | 0 .. 3 | 4 | 5 |
2883+ // |representation|noEscape|throws|differentiability|
2884+ // | 0 .. 3 | 4 | 5 | 6 .. 7 |
28842885 //
28852886 enum : unsigned {
2886- RepresentationMask = 0xF << 0 ,
2887- NoEscapeMask = 1 << 4 ,
2888- ThrowsMask = 1 << 5 ,
2889- NumMaskBits = 6
2887+ RepresentationMask = 0xF << 0 ,
2888+ NoEscapeMask = 1 << 4 ,
2889+ ThrowsMask = 1 << 5 ,
2890+ DifferentiabilityMaskOffset = 6 ,
2891+ DifferentiabilityMask = 0x3 << DifferentiabilityMaskOffset,
2892+ NumMaskBits = 8
28902893 };
28912894
28922895 unsigned Bits; // Naturally sized for speed.
@@ -2909,13 +2912,24 @@ class AnyFunctionType : public TypeBase {
29092912 // Constructor with no defaults.
29102913 ExtInfo (Representation Rep,
29112914 bool IsNoEscape,
2912- bool Throws)
2915+ bool Throws,
2916+ DifferentiabilityKind DiffKind)
29132917 : ExtInfo(Rep, Throws) {
29142918 Bits |= (IsNoEscape ? NoEscapeMask : 0 );
2919+ Bits |= ((unsigned )DiffKind << DifferentiabilityMaskOffset) &
2920+ DifferentiabilityMask;
29152921 }
29162922
29172923 bool isNoEscape () const { return Bits & NoEscapeMask; }
29182924 bool throws () const { return Bits & ThrowsMask; }
2925+ bool isDifferentiable () const {
2926+ return getDifferentiabilityKind () >
2927+ DifferentiabilityKind::NonDifferentiable;
2928+ }
2929+ DifferentiabilityKind getDifferentiabilityKind () const {
2930+ return DifferentiabilityKind ((Bits & DifferentiabilityMask) >>
2931+ DifferentiabilityMaskOffset);
2932+ }
29192933 Representation getRepresentation () const {
29202934 unsigned rawRep = Bits & RepresentationMask;
29212935 assert (rawRep <= unsigned (Representation::Last)
@@ -3073,6 +3087,11 @@ class AnyFunctionType : public TypeBase {
30733087 return getExtInfo ().throws ();
30743088 }
30753089
3090+ bool isDifferentiable () const { return getExtInfo ().isDifferentiable (); }
3091+ DifferentiabilityKind getDifferentiabilityKind () const {
3092+ return getExtInfo ().getDifferentiabilityKind ();
3093+ }
3094+
30763095 // / Returns a new function type exactly like this one but with the ExtInfo
30773096 // / replaced.
30783097 AnyFunctionType *withExtInfo (ExtInfo info) const ;
@@ -3731,14 +3750,16 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
37313750 // If bits are added or removed, then TypeBase::SILFunctionTypeBits
37323751 // and NumMaskBits must be updated, and they must match.
37333752
3734- // |representation|pseudogeneric| noescape |
3735- // | 0 .. 3 | 4 | 5 |
3753+ // |representation|pseudogeneric| noescape |differentiability|
3754+ // | 0 .. 3 | 4 | 5 | 6 .. 7 |
37363755 //
37373756 enum : unsigned {
37383757 RepresentationMask = 0xF << 0 ,
37393758 PseudogenericMask = 1 << 4 ,
37403759 NoEscapeMask = 1 << 5 ,
3741- NumMaskBits = 6
3760+ DifferentiabilityMaskOffset = 6 ,
3761+ DifferentiabilityMask = 0x3 << DifferentiabilityMaskOffset,
3762+ NumMaskBits = 8
37423763 };
37433764
37443765 unsigned Bits; // Naturally sized for speed.
@@ -3752,10 +3773,13 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
37523773 ExtInfo () : Bits(0 ) { }
37533774
37543775 // Constructor for polymorphic type.
3755- ExtInfo (Representation rep, bool isPseudogeneric, bool isNoEscape) {
3776+ ExtInfo (Representation rep, bool isPseudogeneric, bool isNoEscape,
3777+ DifferentiabilityKind diffKind) {
37563778 Bits = ((unsigned ) rep) |
37573779 (isPseudogeneric ? PseudogenericMask : 0 ) |
3758- (isNoEscape ? NoEscapeMask : 0 );
3780+ (isNoEscape ? NoEscapeMask : 0 ) |
3781+ (((unsigned )diffKind << DifferentiabilityMaskOffset) &
3782+ DifferentiabilityMask);
37593783 }
37603784
37613785 // / Is this function pseudo-generic? A pseudo-generic function
@@ -3765,6 +3789,16 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
37653789 // Is this function guaranteed to be no-escape by the type system?
37663790 bool isNoEscape () const { return Bits & NoEscapeMask; }
37673791
3792+ bool isDifferentiable () const {
3793+ return getDifferentiabilityKind () !=
3794+ DifferentiabilityKind::NonDifferentiable;
3795+ }
3796+
3797+ DifferentiabilityKind getDifferentiabilityKind () const {
3798+ return DifferentiabilityKind ((Bits & DifferentiabilityMask) >>
3799+ DifferentiabilityMaskOffset);
3800+ }
3801+
37683802 // / What is the abstract representation of this function value?
37693803 Representation getRepresentation () const {
37703804 return Representation (Bits & RepresentationMask);
@@ -4169,6 +4203,11 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
41694203 getRepresentation () == SILFunctionTypeRepresentation::Thick;
41704204 }
41714205
4206+ bool isDifferentiable () const { return getExtInfo ().isDifferentiable (); }
4207+ DifferentiabilityKind getDifferentiabilityKind () const {
4208+ return getExtInfo ().getDifferentiabilityKind ();
4209+ }
4210+
41724211 bool isNoReturnFunction (SILModule &M) const ; // Defined in SILType.cpp
41734212
41744213 // / Create a SILFunctionType with the same parameters, results, and attributes as this one, but with
0 commit comments