@@ -315,15 +315,12 @@ getDifferentiabilityParameters(SILFunctionType *originalFnTy,
315315// / Collects the semantic results of the given function type in
316316// / `originalResults`. The semantic results are formal results followed by
317317// / `inout` parameters, in type order.
318- // TODO(TF-983): Generalize to support multiple `inout` parameters. The current
319- // singular `inoutParam` and `isWrtInoutParameter` are hacky.
320318static void
321319getSemanticResults (SILFunctionType *functionType, IndexSubset *parameterIndices,
322- Optional<SILParameterInfo> &inoutParam,
323- bool &isWrtInoutParameter,
320+ IndexSubset *&inoutParameterIndices,
324321 SmallVectorImpl<SILResultInfo> &originalResults) {
325- inoutParam = None ;
326- isWrtInoutParameter = false ;
322+ auto &C = functionType-> getASTContext () ;
323+ SmallVector< unsigned , 4 > inoutParamIndices ;
327324 // Collect original formal results.
328325 originalResults.append (functionType->getResults ().begin (),
329326 functionType->getResults ().end ());
@@ -332,11 +329,12 @@ getSemanticResults(SILFunctionType *functionType, IndexSubset *parameterIndices,
332329 auto param = functionType->getParameters ()[i];
333330 if (!param.isIndirectInOut ())
334331 continue ;
335- inoutParam = param;
336- isWrtInoutParameter = parameterIndices->contains (i);
332+ inoutParamIndices.push_back (i);
337333 originalResults.push_back (
338334 SILResultInfo (param.getInterfaceType (), ResultConvention::Indirect));
339335 }
336+ inoutParameterIndices =
337+ IndexSubset::get (C, parameterIndices->getCapacity (), inoutParamIndices);
340338}
341339
342340// / Returns the differential type for the given original function type,
@@ -402,11 +400,10 @@ static CanSILFunctionType getAutoDiffDifferentialType(
402400 SmallVector<Type, 4 > substReplacements;
403401 SmallVector<ProtocolConformanceRef, 4 > substConformances;
404402
405- Optional<SILParameterInfo> inoutParam = None;
406- bool isWrtInoutParameter = false ;
403+ IndexSubset *inoutParamIndices;
407404 SmallVector<SILResultInfo, 2 > originalResults;
408- getSemanticResults (originalFnTy, parameterIndices, inoutParam ,
409- isWrtInoutParameter, originalResults);
405+ getSemanticResults (originalFnTy, parameterIndices, inoutParamIndices ,
406+ originalResults);
410407
411408 SmallVector<SILParameterInfo, 4 > diffParams;
412409 getDifferentiabilityParameters (originalFnTy, parameterIndices, diffParams);
@@ -430,7 +427,7 @@ static CanSILFunctionType getAutoDiffDifferentialType(
430427 }
431428 }
432429 SmallVector<SILResultInfo, 1 > differentialResults;
433- if (!inoutParam || !isWrtInoutParameter ) {
430+ if (inoutParamIndices-> isEmpty () ) {
434431 for (auto resultIndex : resultIndices->getIndices ()) {
435432 auto &result = originalResults[resultIndex];
436433 auto resultTan =
@@ -480,11 +477,10 @@ static CanSILFunctionType getAutoDiffPullbackType(
480477 SmallVector<Type, 4 > substReplacements;
481478 SmallVector<ProtocolConformanceRef, 4 > substConformances;
482479
483- Optional<SILParameterInfo> inoutParam = None;
484- bool isWrtInoutParameter = false ;
480+ IndexSubset *inoutParamIndices;
485481 SmallVector<SILResultInfo, 2 > originalResults;
486- getSemanticResults (originalFnTy, parameterIndices, inoutParam ,
487- isWrtInoutParameter, originalResults);
482+ getSemanticResults (originalFnTy, parameterIndices, inoutParamIndices ,
483+ originalResults);
488484
489485 // Given a type, returns its formal SIL parameter info.
490486 auto getTangentParameterConventionForOriginalResult =
@@ -551,27 +547,11 @@ static CanSILFunctionType getAutoDiffPullbackType(
551547 return conv;
552548 };
553549
550+ // Collect pullback parameters.
554551 SmallVector<SILParameterInfo, 1 > pullbackParams;
555- if (inoutParam) {
556- auto paramTan = inoutParam->getInterfaceType ()->getAutoDiffTangentSpace (
557- lookupConformance);
558- assert (paramTan && " Parameter type does not have a tangent space?" );
559- auto paramTanConvention = isWrtInoutParameter
560- ? inoutParam->getConvention ()
561- : ParameterConvention::Indirect_In_Guaranteed;
562- auto paramTanType = paramTan->getCanonicalType ();
563- if (!paramTanType->hasArchetype () && !paramTanType->hasTypeParameter ()) {
564- pullbackParams.push_back (
565- SILParameterInfo (paramTanType, paramTanConvention));
566- } else {
567- auto gpIndex = substGenericParams.size ();
568- auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
569- substGenericParams.push_back (gpType);
570- substReplacements.push_back (paramTanType);
571- pullbackParams.push_back ({gpType, paramTanConvention});
572- }
573- } else {
574- for (auto resultIndex : resultIndices->getIndices ()) {
552+ for (auto resultIndex : resultIndices->getIndices ()) {
553+ // Handle formal original result.
554+ if (resultIndex < originalFnTy->getNumResults ()) {
575555 auto &origRes = originalResults[resultIndex];
576556 auto resultTan = origRes.getInterfaceType ()->getAutoDiffTangentSpace (
577557 lookupConformance);
@@ -590,12 +570,46 @@ static CanSILFunctionType getAutoDiffPullbackType(
590570 substReplacements.push_back (resultTanType);
591571 pullbackParams.push_back ({gpType, paramTanConvention});
592572 }
573+ continue ;
574+ }
575+ // Handle original `inout` parameter.
576+ auto inoutParamIndex = resultIndex - originalFnTy->getNumResults ();
577+ auto inoutParamIt = std::next (
578+ originalFnTy->getIndirectMutatingParameters ().begin (), inoutParamIndex);
579+ auto paramIndex =
580+ std::distance (originalFnTy->getParameters ().begin (), &*inoutParamIt);
581+ auto inoutParam = originalFnTy->getParameters ()[paramIndex];
582+ auto paramTan = inoutParam.getInterfaceType ()->getAutoDiffTangentSpace (
583+ lookupConformance);
584+ assert (paramTan && " Parameter type does not have a tangent space?" );
585+ // The pullback parameter convention depends on whether the original `inout`
586+ // paramater is a differentiability parameter.
587+ // - If yes, the pullback parameter convention is `@inout`.
588+ // - If no, the pullback parameter convention is `@in_guaranteed`.
589+ bool isWrtInoutParameter = parameterIndices->contains (paramIndex);
590+ auto paramTanConvention = isWrtInoutParameter
591+ ? inoutParam.getConvention ()
592+ : ParameterConvention::Indirect_In_Guaranteed;
593+ auto paramTanType = paramTan->getCanonicalType ();
594+ if (!paramTanType->hasArchetype () && !paramTanType->hasTypeParameter ()) {
595+ pullbackParams.push_back (
596+ SILParameterInfo (paramTanType, paramTanConvention));
597+ } else {
598+ auto gpIndex = substGenericParams.size ();
599+ auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
600+ substGenericParams.push_back (gpType);
601+ substReplacements.push_back (paramTanType);
602+ pullbackParams.push_back ({gpType, paramTanConvention});
593603 }
594604 }
605+
606+ // Collect pullback results.
595607 SmallVector<SILParameterInfo, 4 > diffParams;
596608 getDifferentiabilityParameters (originalFnTy, parameterIndices, diffParams);
597609 SmallVector<SILResultInfo, 8 > pullbackResults;
598610 for (auto ¶m : diffParams) {
611+ // Skip `inout` parameters, which semantically behave as original results
612+ // and always appear as pullback parameters.
599613 if (param.isIndirectInOut ())
600614 continue ;
601615 auto paramTan =
0 commit comments