@@ -51,29 +51,45 @@ class AdjointValueBase {
5151 // / The type of this value as if it were materialized as a SIL value.
5252 SILType type;
5353
54+ using DebugInfo = std::pair<SILDebugLocation, SILDebugVariable>;
55+
56+ // / The debug location and variable info associated with the original value.
57+ Optional<DebugInfo> debugInfo;
58+
5459 // / The underlying value.
5560 union Value {
56- llvm::ArrayRef<AdjointValue> aggregate ;
61+ unsigned numAggregateElements ;
5762 SILValue concrete;
58- Value (llvm::ArrayRef<AdjointValue> v) : aggregate (v) {}
63+ Value (unsigned numAggregateElements)
64+ : numAggregateElements (numAggregateElements) {}
5965 Value (SILValue v) : concrete (v) {}
6066 Value () {}
6167 } value;
6268
69+ // Begins tail-allocated aggregate elements, if
70+ // `kind == AdjointValueKind::Aggregate`.
71+
6372 explicit AdjointValueBase (SILType type,
64- llvm::ArrayRef<AdjointValue> aggregate)
65- : kind(AdjointValueKind::Aggregate), type(type), value(aggregate) {}
73+ llvm::ArrayRef<AdjointValue> aggregate,
74+ Optional<DebugInfo> debugInfo)
75+ : kind(AdjointValueKind::Aggregate), type(type), debugInfo(debugInfo),
76+ value(aggregate.size()) {
77+ MutableArrayRef<AdjointValue> tailElements (
78+ reinterpret_cast <AdjointValue *>(this + 1 ), aggregate.size ());
79+ std::uninitialized_copy (
80+ aggregate.begin (), aggregate.end (), tailElements.begin ());
81+ }
6682
67- explicit AdjointValueBase (SILValue v)
68- : kind(AdjointValueKind::Concrete), type(v->getType ()), value(v) {}
83+ explicit AdjointValueBase (SILValue v, Optional<DebugInfo> debugInfo)
84+ : kind(AdjointValueKind::Concrete), type(v->getType ()),
85+ debugInfo(debugInfo), value(v) {}
6986
70- explicit AdjointValueBase (SILType type)
71- : kind(AdjointValueKind::Zero), type(type) {}
87+ explicit AdjointValueBase (SILType type, Optional<DebugInfo> debugInfo )
88+ : kind(AdjointValueKind::Zero), type(type), debugInfo(debugInfo) {}
7289};
7390
74- // / A symbolic adjoint value that is capable of representing zero value 0 and
75- // / 1, in addition to a materialized SILValue. This is expected to be passed
76- // / around by value in most cases, as it's two words long.
91+ // / A symbolic adjoint value that wraps a `SILValue`, a zero, or an aggregate
92+ // / thereof.
7793class AdjointValue final {
7894
7995private:
@@ -85,26 +101,37 @@ class AdjointValue final {
85101 AdjointValueBase *operator ->() const { return base; }
86102 AdjointValueBase &operator *() const { return *base; }
87103
88- static AdjointValue createConcrete (llvm::BumpPtrAllocator &allocator,
89- SILValue value) {
90- return new (allocator.Allocate <AdjointValueBase>()) AdjointValueBase (value);
104+ using DebugInfo = AdjointValueBase::DebugInfo;
105+
106+ static AdjointValue createConcrete (
107+ llvm::BumpPtrAllocator &allocator, SILValue value,
108+ Optional<DebugInfo> debugInfo = None) {
109+ auto *buf = allocator.Allocate <AdjointValueBase>();
110+ return new (buf) AdjointValueBase (value, debugInfo);
91111 }
92112
93- static AdjointValue createZero (llvm::BumpPtrAllocator &allocator,
94- SILType type) {
95- return new (allocator.Allocate <AdjointValueBase>()) AdjointValueBase (type);
113+ static AdjointValue createZero (
114+ llvm::BumpPtrAllocator &allocator, SILType type,
115+ Optional<DebugInfo> debugInfo = None) {
116+ auto *buf = allocator.Allocate <AdjointValueBase>();
117+ return new (buf) AdjointValueBase (type, debugInfo);
96118 }
97119
98- static AdjointValue createAggregate (llvm::BumpPtrAllocator &allocator,
99- SILType type,
100- llvm::ArrayRef<AdjointValue> aggregate) {
101- return new (allocator.Allocate <AdjointValueBase>())
102- AdjointValueBase (type, aggregate);
120+ static AdjointValue createAggregate (
121+ llvm::BumpPtrAllocator &allocator, SILType type,
122+ ArrayRef<AdjointValue> elements,
123+ Optional<DebugInfo> debugInfo = None) {
124+ AdjointValue *buf = reinterpret_cast <AdjointValue *>(allocator.Allocate (
125+ sizeof (AdjointValueBase) + elements.size () * sizeof (AdjointValue),
126+ alignof (AdjointValueBase)));
127+ return new (buf) AdjointValueBase (type, elements, debugInfo);
103128 }
104129
105130 AdjointValueKind getKind () const { return base->kind ; }
106131 SILType getType () const { return base->type ; }
107132 CanType getSwiftType () const { return getType ().getASTType (); }
133+ Optional<DebugInfo> getDebugInfo () const { return base->debugInfo ; }
134+ void setDebugInfo (DebugInfo debugInfo) const { base->debugInfo = debugInfo; }
108135
109136 NominalTypeDecl *getAnyNominal () const {
110137 return getSwiftType ()->getAnyNominal ();
@@ -116,16 +143,18 @@ class AdjointValue final {
116143
117144 unsigned getNumAggregateElements () const {
118145 assert (isAggregate ());
119- return base->value .aggregate . size () ;
146+ return base->value .numAggregateElements ;
120147 }
121148
122149 AdjointValue getAggregateElement (unsigned i) const {
123- assert (isAggregate ());
124- return base->value .aggregate [i];
150+ return getAggregateElements ()[i];
125151 }
126152
127153 llvm::ArrayRef<AdjointValue> getAggregateElements () const {
128- return base->value .aggregate ;
154+ assert (isAggregate ());
155+ return {
156+ reinterpret_cast <const AdjointValue *>(base + 1 ),
157+ getNumAggregateElements ()};
129158 }
130159
131160 SILValue getConcreteValue () const {
@@ -143,15 +172,15 @@ class AdjointValue final {
143172 if (auto *decl =
144173 getType ().getASTType ()->getStructOrBoundGenericStruct ()) {
145174 interleave (
146- llvm::zip (decl->getStoredProperties (), base-> value . aggregate ),
175+ llvm::zip (decl->getStoredProperties (), getAggregateElements () ),
147176 [&s](std::tuple<VarDecl *, const AdjointValue &> elt) {
148177 s << std::get<0 >(elt)->getName () << " : " ;
149178 std::get<1 >(elt).print (s);
150179 },
151180 [&s] { s << " , " ; });
152181 } else if (getType ().is <TupleType>()) {
153182 interleave (
154- base-> value . aggregate ,
183+ getAggregateElements () ,
155184 [&s](const AdjointValue &elt) { elt.print (s); },
156185 [&s] { s << " , " ; });
157186 } else {
0 commit comments