Skip to content

Commit c3c776a

Browse files
committed
Factor out common code to ASTContext
1 parent d864dcd commit c3c776a

File tree

4 files changed

+62
-89
lines changed

4 files changed

+62
-89
lines changed

include/swift/AST/ASTContext.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1124,6 +1124,16 @@ class ASTContext final {
11241124
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
11251125
llvm::SetVector<AutoDiffConfig> &results);
11261126

1127+
/// Given `Optional<T>.TangentVector` type, retreive the
1128+
/// `Optional<T>.TangentVector.init` declaration.
1129+
ConstructorDecl *getOptionalTanInitDecl(CanType optionalTanType);
1130+
1131+
/// Optional<T>.TangentVector is a struct with a single
1132+
/// Optional<T.TangentVector> `value` property. This is an implementation
1133+
/// detail of OptionalDifferentiation.swift. Retreive `VarDecl` corresponding
1134+
/// to this property.
1135+
VarDecl *getOptionalTanValueDecl(CanType optionalTanType);
1136+
11271137
/// Retrieve the next macro expansion discriminator within the given
11281138
/// name and context.
11291139
unsigned getNextMacroDiscriminator(MacroDiscriminatorContext context,

lib/AST/ASTContext.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2245,6 +2245,44 @@ void ASTContext::loadObjCMethods(
22452245
}
22462246
}
22472247

2248+
ConstructorDecl *ASTContext::getOptionalTanInitDecl(CanType optionalTanType) {
2249+
auto *optionalTanDecl = optionalTanType.getNominalOrBoundGenericNominal();
2250+
// Look up the `Optional<T>.TangentVector.init` declaration.
2251+
auto initLookup =
2252+
optionalTanDecl->lookupDirect(DeclBaseName::createConstructor());
2253+
ConstructorDecl *constructorDecl = nullptr;
2254+
for (auto *candidate : initLookup) {
2255+
auto candidateModule = candidate->getModuleContext();
2256+
if (candidateModule->getName() == Id_Differentiation ||
2257+
candidateModule->isStdlibModule()) {
2258+
assert(!constructorDecl && "Multiple `Optional.TangentVector.init`s");
2259+
constructorDecl = cast<ConstructorDecl>(candidate);
2260+
#ifdef NDEBUG
2261+
break;
2262+
#endif
2263+
}
2264+
}
2265+
assert(constructorDecl && "No `Optional.TangentVector.init`");
2266+
2267+
return constructorDecl;
2268+
}
2269+
2270+
VarDecl *ASTContext::getOptionalTanValueDecl(CanType optionalTanType) {
2271+
// TODO: Maybe it would be better to have getters / setters here that we
2272+
// can call and hide this implementation detail?
2273+
StructDecl *optStructDecl = optionalTanType.getStructOrBoundGenericStruct();
2274+
assert(optStructDecl && "Unexpected type of Optional.TangentVector");
2275+
2276+
ArrayRef<VarDecl *> properties = optStructDecl->getStoredProperties();
2277+
assert(properties.size() == 1 && "Unexpected type of Optional.TangentVector");
2278+
VarDecl *wrappedValueVar = properties[0];
2279+
2280+
assert(wrappedValueVar->getTypeInContext()->getEnumOrBoundGenericEnum() ==
2281+
getOptionalDecl() && "Unexpected type of Optional.TangentVector");
2282+
2283+
return wrappedValueVar;
2284+
}
2285+
22482286
void ASTContext::loadDerivativeFunctionConfigurations(
22492287
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
22502288
llvm::SetVector<AutoDiffConfig> &results) {

lib/SILGen/SILGenPoly.cpp

Lines changed: 8 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -301,24 +301,8 @@ SILGenFunction::emitTransformExistential(SILLocation loc,
301301
ManagedValue SILGenFunction::emitTangentVectorToOptionalTangentVector(
302302
SILLocation loc, ManagedValue input, CanType wrappedType, CanType inputType,
303303
CanType outputType, SGFContext ctxt) {
304-
auto *optionalTanDecl = outputType.getNominalOrBoundGenericNominal();
305304
// Look up the `Optional<T>.TangentVector.init` declaration.
306-
auto initLookup =
307-
optionalTanDecl->lookupDirect(DeclBaseName::createConstructor());
308-
ConstructorDecl *constructorDecl = nullptr;
309-
for (auto *candidate : initLookup) {
310-
auto candidateModule = candidate->getModuleContext();
311-
if (candidateModule->getName() ==
312-
getASTContext().Id_Differentiation ||
313-
candidateModule->isStdlibModule()) {
314-
assert(!constructorDecl && "Multiple `Optional.TangentVector.init`s");
315-
constructorDecl = cast<ConstructorDecl>(candidate);
316-
#ifdef NDEBUG
317-
break;
318-
#endif
319-
}
320-
}
321-
assert(constructorDecl && "No `Optional.TangentVector.init`");
305+
auto *constructorDecl = getASTContext().getOptionalTanInitDecl(outputType);
322306

323307
// `Optional<T.TangentVector>`
324308
CanType optionalOfWrappedTanType = inputType.wrapInOptionalType();
@@ -346,26 +330,16 @@ ManagedValue SILGenFunction::emitOptionalTangentVectorToTangentVector(
346330
SILLocation loc, ManagedValue input, CanType wrappedType, CanType inputType,
347331
CanType outputType, SGFContext ctxt) {
348332
// Optional<T>.TangentVector should be a struct with a single
349-
// Optional<T.TangentVector> property. This is an implementation detail of
350-
// OptionalDifferentiation.swift
351-
// TODO: Maybe it would be better to have getters / setters here that we
352-
// can call and hide this implementation detail?
353-
StructDecl *optStructDecl = inputType.getStructOrBoundGenericStruct();
354-
assert(optStructDecl && "Unexpected type of Optional.TangentVector");
355-
356-
ArrayRef<VarDecl *> properties = optStructDecl->getStoredProperties();
357-
assert(properties.size() == 1 && "Unexpected type of Optional.TangentVector");
358-
VarDecl *wrappedValueVar = properties[0];
359-
360-
assert(wrappedValueVar->getTypeInContext()->getEnumOrBoundGenericEnum() ==
361-
getASTContext().getOptionalDecl() &&
362-
"Unexpected type of Optional.TangentVector");
363-
364-
FormalEvaluationScope scope(*this);
365-
333+
// Optional<T.TangentVector> `value` property. This is an implementation
334+
// detail of OptionalDifferentiation.swift
335+
// TODO: Maybe it would be better to have explicit getters / setters here that we can
336+
// call and hide this implementation detail?
337+
VarDecl *wrappedValueVar = getASTContext().getOptionalTanValueDecl(inputType);
366338
// `Optional<T.TangentVector>`
367339
CanType optionalOfWrappedTanType = outputType.wrapInOptionalType();
368340

341+
FormalEvaluationScope scope(*this);
342+
369343
auto sig = wrappedValueVar->getDeclContext()->getGenericSignatureOfContext();
370344
auto *diffProto =
371345
getASTContext().getProtocol(KnownProtocolKind::Differentiable);

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 6 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1852,27 +1852,9 @@ class PullbackCloner::Implementation final
18521852

18531853
auto adjOpt = getAdjointValue(bb, ei);
18541854
auto adjStruct = materializeAdjointDirect(adjOpt, loc);
1855-
StructDecl *adjStructDecl =
1856-
adjStruct->getType().getStructOrBoundGenericStruct();
1857-
1858-
VarDecl *adjOptVar = nullptr;
1859-
if (adjStructDecl) {
1860-
ArrayRef<VarDecl *> properties = adjStructDecl->getStoredProperties();
1861-
adjOptVar = properties.size() == 1 ? properties[0] : nullptr;
1862-
}
1863-
1864-
EnumDecl *adjOptDecl =
1865-
adjOptVar ? adjOptVar->getTypeInContext()->getEnumOrBoundGenericEnum()
1866-
: nullptr;
1867-
1868-
// Optional<T>.TangentVector should be a struct with a single
1869-
// Optional<T.TangentVector> property. This is an implementation detail of
1870-
// OptionalDifferentiation.swift
1871-
// TODO: Maybe it would be better to have getters / setters here that we
1872-
// can call and hide this implementation detail?
1873-
if (!adjOptDecl || adjOptDecl != optionalEnumDecl)
1874-
llvm_unreachable("Unexpected type of Optional.TangentVector");
18751855

1856+
VarDecl *adjOptVar =
1857+
getASTContext().getOptionalTanValueDecl(adjStruct->getType().getASTType());
18761858
auto *adjVal = builder.createStructExtract(loc, adjStruct, adjOptVar);
18771859

18781860
EnumElementDecl *someElemDecl = getASTContext().getOptionalSomeDecl();
@@ -1931,24 +1913,8 @@ class PullbackCloner::Implementation final
19311913
}
19321914

19331915
SILValue adjDest = getAdjointBuffer(bb, origEnum);
1934-
StructDecl *adjStructDecl =
1935-
adjDest->getType().getStructOrBoundGenericStruct();
1936-
1937-
VarDecl *adjOptVar = nullptr;
1938-
if (adjStructDecl) {
1939-
ArrayRef<VarDecl *> properties = adjStructDecl->getStoredProperties();
1940-
adjOptVar = properties.size() == 1 ? properties[0] : nullptr;
1941-
}
1942-
1943-
EnumDecl *adjOptDecl =
1944-
adjOptVar ? adjOptVar->getTypeInContext()->getEnumOrBoundGenericEnum()
1945-
: nullptr;
1946-
1947-
// Optional<T>.TangentVector should be a struct with a single
1948-
// Optional<T.TangentVector> property. This is an implementation detail of
1949-
// OptionalDifferentiation.swift
1950-
if (!adjOptDecl || adjOptDecl != optionalEnumDecl)
1951-
llvm_unreachable("Unexpected type of Optional.TangentVector");
1916+
VarDecl *adjOptVar =
1917+
getASTContext().getOptionalTanValueDecl(adjDest->getType().getASTType());
19521918

19531919
SILLocation loc = origData->getLoc();
19541920
StructElementAddrInst *adjOpt =
@@ -2678,24 +2644,9 @@ AllocStackInst *PullbackCloner::Implementation::createOptionalAdjoint(
26782644
auto optionalOfWrappedTanType = SILType::getOptionalType(wrappedTanType);
26792645
// `Optional<T>.TangentVector`
26802646
auto optionalTanTy = getRemappedTangentType(optionalTy);
2681-
auto *optionalTanDecl = optionalTanTy.getNominalOrBoundGenericNominal();
26822647
// Look up the `Optional<T>.TangentVector.init` declaration.
2683-
auto initLookup =
2684-
optionalTanDecl->lookupDirect(DeclBaseName::createConstructor());
2685-
ConstructorDecl *constructorDecl = nullptr;
2686-
for (auto *candidate : initLookup) {
2687-
auto candidateModule = candidate->getModuleContext();
2688-
if (candidateModule->getName() ==
2689-
builder.getASTContext().Id_Differentiation ||
2690-
candidateModule->isStdlibModule()) {
2691-
assert(!constructorDecl && "Multiple `Optional.TangentVector.init`s");
2692-
constructorDecl = cast<ConstructorDecl>(candidate);
2693-
#ifdef NDEBUG
2694-
break;
2695-
#endif
2696-
}
2697-
}
2698-
assert(constructorDecl && "No `Optional.TangentVector.init`");
2648+
ConstructorDecl *constructorDecl =
2649+
getASTContext().getOptionalTanInitDecl(optionalTanTy.getASTType());
26992650

27002651
// Allocate a local buffer for the `Optional` adjoint value.
27012652
auto *optTanAdjBuf = builder.createAllocStack(pbLoc, optionalTanTy);

0 commit comments

Comments
 (0)