-
Notifications
You must be signed in to change notification settings - Fork 10.6k
[AutoDiff] Fix differentiation for non-wrt inout parameters.
#33304
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
lib/SIL/IR/SILFunctionType.cpp
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Quick paste of solution ideas from TF-1305:
- Add new SILParameterDifferentiability kinds, e.g.
InoutParameterNotDifferentiableParameter(case 3) andInoutParameterNotDifferentiableResult(case 4).- Consider whether this needs to be exposed in AST
@differentiablefunction types.
- Consider whether this needs to be exposed in AST
- Find some way to store parameter/result differentiability in
SILFunctionTypeinstead of individualSILParameterInfo/SILResultInfo.
Fix SIL differential function type calculation to handle non-wrt `inout` parameters. Patch `SILFunctionType::getDifferentiabilityResultIndices` to prevent returning empty result indices for `@differentiable` function types with no formal results where all `inout` parameters are `@noDerivative`. TF-1305 tracks a robust fix. Resolves SR-13305. Exposes TF-1305: parameter/result differentiability hole for `inout` parameters.
|
@swift-ci Please test |
lib/SIL/IR/SILFunctionType.cpp
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@rxwei mentioned a nuanced point, which is that there are no clear use cases for inout parameters that are differentiability parameters but not differentiability results (case 4).
That case is indeed not currently expressible, so we can deprioritize support for it: I clarified that in the comment. Supporting that case would involve something like adding a results: clause to @differentiable declaration attribute:
// Pseudo-syntax for supporting case 4:
// `inout` parameter that is a differentiability parameter but not a differentiability result.
@differentiable(wrt: inoutParam, results: return)
func foo(_ inoutParam: inout Float) -> Float {
return inoutParam
}The focus of this PR is supporting inout parameters that are differentiability results but that are not differentiability parameters (case 3), which is currently expressible and reasonable to support.
// Case 3: `inout` parameter that is a differentiability result but not a differentiability parameter.
@differentiable(wrt: x)
func foo(_ x: Float, _ inoutParam: inout Float) {
inoutParam = x * x
}… but not diff. results. There are no clear use cases for `inout` parameters that are differentiability parameters but not differentiability results, so we can de-prioritize support for it. Clarify this in the comment regarding TF-1305.
|
@swift-ci Please test |
|
Build failed |
|
Build failed |
Fix SIL differential function type calculation to handle non-wrt
inoutparameters.
Patch
SILFunctionType::getDifferentiabilityResultIndicesto prevent returningempty result indices for
@differentiablefunction types with no formal resultswhere all
inoutparameters are@noDerivative. TF-1305 tracks a robust fix.Resolves SR-13305.
Exposes TF-1305: parameter/result differentiability hole for
inoutparameters.