Skip to content

Commit cf717ce

Browse files
committed
[AutoDiff] Closure specialization: specialize branch tracing enums
This patch contains part of the changes intended to resolve #68944. 1. Closure info gathering logic. 2. Branch tracing enum specialization logic. 3. Specialization of branch tracing enum basic block arguments in VJP. 4. Specialization of branch tracing enum payload basic block arguments in pullback. 5. C++ helpers for things which are not yet bridged to Swift, such as demangling-related logic.
1 parent 10085d6 commit cf717ce

File tree

12 files changed

+1117
-15
lines changed

12 files changed

+1117
-15
lines changed

SwiftCompilerSources/Sources/Optimizer/FunctionPasses/ClosureSpecialization.swift

Lines changed: 675 additions & 1 deletion
Large diffs are not rendered by default.

SwiftCompilerSources/Sources/Optimizer/Utilities/FunctionTest.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ public func registerOptimizerTests() {
4444
addressOwnershipLiveRangeTest,
4545
argumentConventionsTest,
4646
getPullbackClosureInfoTest,
47+
getPullbackClosureInfoMultiBBTest,
4748
interiorLivenessTest,
4849
lifetimeDependenceRootTest,
4950
lifetimeDependenceScopeTest,
@@ -53,6 +54,9 @@ public func registerOptimizerTests() {
5354
localVariableReachingAssignmentsTest,
5455
rangeOverlapsPathTest,
5556
rewrittenCallerBodyTest,
57+
specializeBranchTracingEnums,
58+
specializeBTEArgInVjpBB,
59+
specializePayloadArgInPullbackBB,
5660
specializedFunctionSignatureAndBodyTest,
5761
variableIntroducerTest
5862
)

include/module.modulemap

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,8 @@ module OptimizerBridging {
2626
header "swift/SILOptimizer/OptimizerBridging.h"
2727
export *
2828
}
29+
30+
module AutoDiffClosureSpecializationBridging {
31+
header "swift/SILOptimizer/AutoDiffClosureSpecializationBridging.h"
32+
export *
33+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
//===--- AutoDiffClosureSpecializationBridging.h --------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2025 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef SWIFT_SILOPTIMIZER_ADCSBRIDGING_H
14+
#define SWIFT_SILOPTIMIZER_ADCSBRIDGING_H
15+
16+
#include "swift/AST/ASTBridging.h"
17+
#include "swift/Basic/BasicBridging.h"
18+
#include "swift/SIL/SILBridging.h"
19+
20+
SWIFT_BEGIN_NULLABILITY_ANNOTATIONS
21+
22+
SWIFT_IMPORT_UNSAFE BridgedType
23+
getBranchingTraceEnumLoweredType(BridgedDeclObj ed, BridgedFunction vjp);
24+
25+
SWIFT_IMPORT_UNSAFE BridgedType
26+
getBranchingTraceEnumLoweredType(BridgedEnumDecl ed, BridgedFunction vjp);
27+
28+
SWIFT_IMPORT_UNSAFE BridgedNullableGenericParamList
29+
cloneGenericParameters(BridgedASTContext ctx, BridgedDeclContext dc,
30+
BridgedCanGenericSignature sig);
31+
32+
SWIFT_IMPORT_UNSAFE BridgedSourceFile autodiffGetSourceFile(BridgedFunction f);
33+
34+
SWIFT_IMPORT_UNSAFE BridgedOwnedString getEnumDeclAsString(BridgedType bteType);
35+
36+
SWIFT_END_NULLABILITY_ANNOTATIONS
37+
38+
#endif

include/swift/SILOptimizer/Differentiation/ADContext.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,9 @@ ADContext::emitNondifferentiabilityError(SourceLoc loc,
418418
llvm_unreachable("Invalid invoker kind"); // silences MSVC C4715
419419
}
420420

421+
/// Get the source file for the given `SILFunction`.
422+
SourceFile &getSourceFile(SILFunction *f);
423+
421424
} // end namespace autodiff
422425
} // end namespace swift
423426

include/swift/SILOptimizer/Differentiation/LinearMapInfo.h

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,14 @@ class SILLoopInfo;
3030

3131
namespace autodiff {
3232

33+
inline SILType getLoweredTypeImpl(CanType ty, SILFunction *derivative,
34+
Lowering::TypeConverter &typeConverter) {
35+
Lowering::AbstractionPattern pattern(
36+
derivative->getLoweredFunctionType()->getSubstGenericSignature(), ty);
37+
return typeConverter.getLoweredType(pattern, ty,
38+
TypeExpansionContext::minimal());
39+
}
40+
3341
class ADContext;
3442

3543
/// Linear map struct and branching trace enum information for an original
@@ -161,13 +169,9 @@ class LinearMapInfo {
161169
/// the given original block.
162170
SILType getBranchingTraceEnumLoweredType(SILBasicBlock *origBB) const {
163171
auto *traceDecl = getBranchingTraceDecl(origBB);
164-
auto traceDeclType =
165-
traceDecl->getDeclaredInterfaceType()->getCanonicalType();
166-
Lowering::AbstractionPattern pattern(
167-
derivative->getLoweredFunctionType()->getSubstGenericSignature(),
168-
traceDeclType);
169-
return typeConverter.getLoweredType(pattern, traceDeclType,
170-
TypeExpansionContext::minimal());
172+
return getLoweredTypeImpl(
173+
traceDecl->getDeclaredInterfaceType()->getCanonicalType(), derivative,
174+
typeConverter);
171175
}
172176

173177
/// Returns the enum element in the given successor block's branching trace
@@ -199,6 +203,11 @@ class LinearMapInfo {
199203
}
200204
};
201205

206+
/// Clone the generic parameters of the given generic signature and return a new
207+
/// `GenericParamList`.
208+
GenericParamList *cloneGenericParameters(ASTContext &ctx, DeclContext *dc,
209+
CanGenericSignature sig);
210+
202211
} // end namespace autodiff
203212
} // end namespace swift
204213

lib/SILOptimizer/Differentiation/ADContext.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,7 @@ ADContext::ADContext(SILModuleTransform &transform)
6060
: transform(transform), module(*transform.getModule()),
6161
passManager(*transform.getPassManager()) {}
6262

63-
/// Get the source file for the given `SILFunction`.
64-
static SourceFile &getSourceFile(SILFunction *f) {
63+
SourceFile &getSourceFile(SILFunction *f) {
6564
if (f->hasLocation())
6665
if (auto *declContext = f->getLocation().getAsDeclContext())
6766
if (auto *parentSourceFile = declContext->getParentSourceFile())

lib/SILOptimizer/Differentiation/LinearMapInfo.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,8 @@ namespace autodiff {
3333
// Local helpers
3434
//===----------------------------------------------------------------------===//
3535

36-
/// Clone the generic parameters of the given generic signature and return a new
37-
/// `GenericParamList`.
38-
static GenericParamList *cloneGenericParameters(ASTContext &ctx,
39-
DeclContext *dc,
40-
CanGenericSignature sig) {
36+
GenericParamList *cloneGenericParameters(ASTContext &ctx, DeclContext *dc,
37+
CanGenericSignature sig) {
4138
SmallVector<GenericTypeParamDecl *, 2> clonedParams;
4239
for (auto paramType : sig.getGenericParams()) {
4340
auto *clonedParam = GenericTypeParamDecl::createImplicit(
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
//===--- AutoDiffClosureSpecializationBridging.cpp ------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2025 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#define DEBUG_TYPE "autodiff-closure-specialization-bridging"
14+
15+
#include "swift/SILOptimizer/AutoDiffClosureSpecializationBridging.h"
16+
#include "swift/SILOptimizer/Differentiation/ADContext.h"
17+
#include "swift/SILOptimizer/Differentiation/LinearMapInfo.h"
18+
19+
using namespace swift;
20+
21+
static SILType getBranchingTraceEnumLoweredTypeImpl(EnumDecl *ed,
22+
SILFunction &vjp) {
23+
return autodiff::getLoweredTypeImpl(
24+
ed->getDeclaredInterfaceType()->getCanonicalType(), &vjp,
25+
vjp.getModule().Types);
26+
}
27+
28+
BridgedType getBranchingTraceEnumLoweredType(BridgedEnumDecl ed,
29+
BridgedFunction vjp) {
30+
return getBranchingTraceEnumLoweredTypeImpl(ed.unbridged(),
31+
*vjp.getFunction());
32+
}
33+
34+
BridgedType getBranchingTraceEnumLoweredType(BridgedDeclObj ed,
35+
BridgedFunction vjp) {
36+
return getBranchingTraceEnumLoweredTypeImpl(ed.getAs<EnumDecl>(),
37+
*vjp.getFunction());
38+
}
39+
40+
BridgedNullableGenericParamList
41+
cloneGenericParameters(BridgedASTContext ctx, BridgedDeclContext dc,
42+
BridgedCanGenericSignature sig) {
43+
return autodiff::cloneGenericParameters(ctx.unbridged(), dc.unbridged(),
44+
sig.unbridged());
45+
}
46+
47+
BridgedSourceFile autodiffGetSourceFile(BridgedFunction f) {
48+
return {&autodiff::getSourceFile(f.getFunction())};
49+
}
50+
51+
BridgedOwnedString getEnumDeclAsString(BridgedType bteType) {
52+
std::string str;
53+
llvm::raw_string_ostream out(str);
54+
bteType.unbridged().getEnumOrBoundGenericEnum()->print(out);
55+
return BridgedOwnedString(/*stringToCopy=*/StringRef(str));
56+
}

lib/SILOptimizer/Utils/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
target_sources(swiftSILOptimizer PRIVATE
2+
AutoDiffClosureSpecializationBridging.cpp
23
BasicBlockOptUtils.cpp
34
CFGOptUtils.cpp
45
CanonicalizeInstruction.cpp

0 commit comments

Comments
 (0)