@@ -1852,27 +1852,9 @@ class PullbackCloner::Implementation final
1852
1852
1853
1853
auto adjOpt = getAdjointValue (bb, ei);
1854
1854
auto adjStruct = materializeAdjointDirect (adjOpt, loc);
1855
- StructDecl *adjStructDecl =
1856
- adjStruct->getType ().getStructOrBoundGenericStruct ();
1857
-
1858
- VarDecl *adjOptVar = nullptr ;
1859
- if (adjStructDecl) {
1860
- ArrayRef<VarDecl *> properties = adjStructDecl->getStoredProperties ();
1861
- adjOptVar = properties.size () == 1 ? properties[0 ] : nullptr ;
1862
- }
1863
-
1864
- EnumDecl *adjOptDecl =
1865
- adjOptVar ? adjOptVar->getTypeInContext ()->getEnumOrBoundGenericEnum ()
1866
- : nullptr ;
1867
-
1868
- // Optional<T>.TangentVector should be a struct with a single
1869
- // Optional<T.TangentVector> property. This is an implementation detail of
1870
- // OptionalDifferentiation.swift
1871
- // TODO: Maybe it would be better to have getters / setters here that we
1872
- // can call and hide this implementation detail?
1873
- if (!adjOptDecl || adjOptDecl != optionalEnumDecl)
1874
- llvm_unreachable (" Unexpected type of Optional.TangentVector" );
1875
1855
1856
+ VarDecl *adjOptVar =
1857
+ getASTContext ().getOptionalTanValueDecl (adjStruct->getType ().getASTType ());
1876
1858
auto *adjVal = builder.createStructExtract (loc, adjStruct, adjOptVar);
1877
1859
1878
1860
EnumElementDecl *someElemDecl = getASTContext ().getOptionalSomeDecl ();
@@ -1931,24 +1913,8 @@ class PullbackCloner::Implementation final
1931
1913
}
1932
1914
1933
1915
SILValue adjDest = getAdjointBuffer (bb, origEnum);
1934
- StructDecl *adjStructDecl =
1935
- adjDest->getType ().getStructOrBoundGenericStruct ();
1936
-
1937
- VarDecl *adjOptVar = nullptr ;
1938
- if (adjStructDecl) {
1939
- ArrayRef<VarDecl *> properties = adjStructDecl->getStoredProperties ();
1940
- adjOptVar = properties.size () == 1 ? properties[0 ] : nullptr ;
1941
- }
1942
-
1943
- EnumDecl *adjOptDecl =
1944
- adjOptVar ? adjOptVar->getTypeInContext ()->getEnumOrBoundGenericEnum ()
1945
- : nullptr ;
1946
-
1947
- // Optional<T>.TangentVector should be a struct with a single
1948
- // Optional<T.TangentVector> property. This is an implementation detail of
1949
- // OptionalDifferentiation.swift
1950
- if (!adjOptDecl || adjOptDecl != optionalEnumDecl)
1951
- llvm_unreachable (" Unexpected type of Optional.TangentVector" );
1916
+ VarDecl *adjOptVar =
1917
+ getASTContext ().getOptionalTanValueDecl (adjDest->getType ().getASTType ());
1952
1918
1953
1919
SILLocation loc = origData->getLoc ();
1954
1920
StructElementAddrInst *adjOpt =
@@ -2678,24 +2644,9 @@ AllocStackInst *PullbackCloner::Implementation::createOptionalAdjoint(
2678
2644
auto optionalOfWrappedTanType = SILType::getOptionalType (wrappedTanType);
2679
2645
// `Optional<T>.TangentVector`
2680
2646
auto optionalTanTy = getRemappedTangentType (optionalTy);
2681
- auto *optionalTanDecl = optionalTanTy.getNominalOrBoundGenericNominal ();
2682
2647
// Look up the `Optional<T>.TangentVector.init` declaration.
2683
- auto initLookup =
2684
- optionalTanDecl->lookupDirect (DeclBaseName::createConstructor ());
2685
- ConstructorDecl *constructorDecl = nullptr ;
2686
- for (auto *candidate : initLookup) {
2687
- auto candidateModule = candidate->getModuleContext ();
2688
- if (candidateModule->getName () ==
2689
- builder.getASTContext ().Id_Differentiation ||
2690
- candidateModule->isStdlibModule ()) {
2691
- assert (!constructorDecl && " Multiple `Optional.TangentVector.init`s" );
2692
- constructorDecl = cast<ConstructorDecl>(candidate);
2693
- #ifdef NDEBUG
2694
- break ;
2695
- #endif
2696
- }
2697
- }
2698
- assert (constructorDecl && " No `Optional.TangentVector.init`" );
2648
+ ConstructorDecl *constructorDecl =
2649
+ getASTContext ().getOptionalTanInitDecl (optionalTanTy.getASTType ());
2699
2650
2700
2651
// Allocate a local buffer for the `Optional` adjoint value.
2701
2652
auto *optTanAdjBuf = builder.createAllocStack (pbLoc, optionalTanTy);
0 commit comments