@@ -3375,42 +3375,20 @@ DynamicallyReplacedDeclRequest::evaluate(Evaluator &evaluator,
33753375 return nullptr ;
33763376}
33773377
3378- // / If the given type conforms to `Differentiable` in the given context, returns
3379- // / the `ProtocolConformanceRef`. Otherwise, returns an invalid
3380- // / `ProtocolConformanceRef`.
3381- // /
3382- // / This helper verifies that the `TangentVector` type witness is valid, in case
3383- // / the conformance has not been fully checked and the type witness cannot be
3384- // / resolved.
3385- static ProtocolConformanceRef getDifferentiableConformance (Type type,
3386- DeclContext *DC) {
3387- auto &ctx = type->getASTContext ();
3388- auto *differentiableProto =
3389- ctx.getProtocol (KnownProtocolKind::Differentiable);
3390- auto conf =
3391- TypeChecker::conformsToProtocol (type, differentiableProto, DC);
3392- if (!conf)
3393- return ProtocolConformanceRef ();
3394- // Try to get the `TangentVector` type witness, in case the conformance has
3395- // not been fully checked.
3396- Type tanType = conf.getTypeWitnessByName (type, ctx.Id_TangentVector );
3397- if (tanType.isNull () || tanType->hasError ())
3398- return ProtocolConformanceRef ();
3399- return conf;
3400- };
3401-
34023378// / Returns true if the given type conforms to `Differentiable` in the given
34033379// / contxt. If `tangentVectorEqualsSelf` is true, also check whether the given
34043380// / type satisfies `TangentVector == Self`.
34053381static bool conformsToDifferentiable (Type type, DeclContext *DC,
34063382 bool tangentVectorEqualsSelf = false ) {
3407- auto conf = getDifferentiableConformance (type, DC);
3383+ auto &ctx = type->getASTContext ();
3384+ auto *differentiableProto =
3385+ ctx.getProtocol (KnownProtocolKind::Differentiable);
3386+ auto conf = TypeChecker::conformsToProtocol (type, differentiableProto, DC);
34083387 if (conf.isInvalid ())
34093388 return false ;
34103389 if (!tangentVectorEqualsSelf)
34113390 return true ;
3412- auto &ctx = type->getASTContext ();
3413- Type tanType = conf.getTypeWitnessByName (type, ctx.Id_TangentVector );
3391+ auto tanType = conf.getTypeWitnessByName (type, ctx.Id_TangentVector );
34143392 return type->isEqual (tanType);
34153393};
34163394
@@ -4602,67 +4580,81 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
46024580 // Set the resolved differentiability parameter indices in the attribute.
46034581 attr->setParameterIndices (resolvedDiffParamIndices);
46044582
4605- // Get the original semantic result.
4606- llvm::SmallVector<AutoDiffSemanticFunctionResultType, 1 > originalResults;
4607- autodiff::getFunctionSemanticResultTypes (
4608- originalFnType, originalResults,
4609- derivative->getGenericEnvironmentOfContext ());
4610- // Check that original function has at least one semantic result, i.e.
4611- // that the original semantic result type is not `Void`.
4612- if (originalResults.empty ()) {
4613- diags
4614- .diagnose (attr->getLocation (), diag::autodiff_attr_original_void_result,
4615- derivative->getName ())
4616- .highlight (attr->getOriginalFunctionName ().Loc .getSourceRange ());
4617- attr->setInvalid ();
4618- return true ;
4619- }
4620- // Check that original function does not have multiple semantic results.
4621- if (originalResults.size () > 1 ) {
4622- diags
4623- .diagnose (attr->getLocation (),
4624- diag::autodiff_attr_original_multiple_semantic_results)
4625- .highlight (attr->getOriginalFunctionName ().Loc .getSourceRange ());
4626- attr->setInvalid ();
4627- return true ;
4628- }
4629- auto originalResult = originalResults.front ();
4630- auto originalResultType = originalResult.type ;
4631- // Check that the original semantic result conforms to `Differentiable`.
4632- auto valueResultConf = getDifferentiableConformance (
4633- originalResultType, derivative->getDeclContext ());
4634- if (!valueResultConf) {
4635- diags.diagnose (attr->getLocation (),
4636- diag::derivative_attr_result_value_not_differentiable,
4637- valueResultElt.getType ());
4583+ // Compute the expected differential/pullback type.
4584+ auto expectedLinearMapTypeOrError =
4585+ originalFnType->getAutoDiffDerivativeFunctionLinearMapType (
4586+ resolvedDiffParamIndices, kind.getLinearMapKind (), lookupConformance,
4587+ /* makeSelfParamFirst*/ true );
4588+
4589+ // Helper for diagnosing derivative function type errors.
4590+ auto errorHandler = [&](const DerivativeFunctionTypeError &error) {
4591+ switch (error.kind ) {
4592+ case DerivativeFunctionTypeError::Kind::NoSemanticResults:
4593+ diags
4594+ .diagnose (attr->getLocation (),
4595+ diag::autodiff_attr_original_multiple_semantic_results)
4596+ .highlight (attr->getOriginalFunctionName ().Loc .getSourceRange ());
4597+ attr->setInvalid ();
4598+ return ;
4599+ case DerivativeFunctionTypeError::Kind::MultipleSemanticResults:
4600+ diags
4601+ .diagnose (attr->getLocation (),
4602+ diag::autodiff_attr_original_multiple_semantic_results)
4603+ .highlight (attr->getOriginalFunctionName ().Loc .getSourceRange ());
4604+ attr->setInvalid ();
4605+ return ;
4606+ case DerivativeFunctionTypeError::Kind::NonDifferentiableParameters: {
4607+ auto *nonDiffParamIndices = error.getNonDifferentiableParameterIndices ();
4608+ SmallVector<AnyFunctionType::Param, 4 > diffParams;
4609+ error.functionType ->getSubsetParameters (resolvedDiffParamIndices,
4610+ diffParams);
4611+ for (unsigned i : range (diffParams.size ())) {
4612+ if (!nonDiffParamIndices->contains (i))
4613+ continue ;
4614+ SourceLoc loc = parsedDiffParams.empty () ? attr->getLocation ()
4615+ : parsedDiffParams[i].getLoc ();
4616+ auto diffParamType = diffParams[i].getPlainType ();
4617+ diags.diagnose (loc, diag::diff_params_clause_param_not_differentiable,
4618+ diffParamType);
4619+ }
4620+ return ;
4621+ }
4622+ case DerivativeFunctionTypeError::Kind::NonDifferentiableResult:
4623+ auto originalResultType = error.getNonDifferentiableResultType ();
4624+ diags.diagnose (attr->getLocation (),
4625+ diag::differentiable_attr_result_not_differentiable,
4626+ originalResultType);
4627+ attr->setInvalid ();
4628+ return ;
4629+ }
4630+ };
4631+ // Diagnose any derivative function type errors.
4632+ if (!expectedLinearMapTypeOrError) {
4633+ auto error = expectedLinearMapTypeOrError.takeError ();
4634+ handleAllErrors (std::move (error), errorHandler);
46384635 return true ;
46394636 }
4640-
4641- // Compute the actual differential/pullback type that we use for comparison
4642- // with the expected type. We must canonicalize the derivative interface type
4643- // before extracting the differential/pullback type from it, so that the
4644- // derivative interface type generic signature is available for simplifying
4645- // types.
4637+ Type expectedLinearMapType = expectedLinearMapTypeOrError.get ();
4638+ if (expectedLinearMapType->hasTypeParameter ())
4639+ expectedLinearMapType =
4640+ derivative->mapTypeIntoContext (expectedLinearMapType);
4641+ if (expectedLinearMapType->hasArchetype ())
4642+ expectedLinearMapType = expectedLinearMapType->mapTypeOutOfContext ();
4643+
4644+ // Compute the actual differential/pullback type for comparison with the
4645+ // expected type. We must canonicalize the derivative interface type before
4646+ // extracting the differential/pullback type from it so that types are
4647+ // simplified via the canonical generic signature.
46464648 CanType canActualResultType = derivativeInterfaceType->getCanonicalType ();
46474649 while (isa<AnyFunctionType>(canActualResultType)) {
46484650 canActualResultType =
46494651 cast<AnyFunctionType>(canActualResultType).getResult ();
46504652 }
4651- CanType actualFuncEltType =
4653+ CanType actualLinearMapType =
46524654 cast<TupleType>(canActualResultType).getElementType (1 );
46534655
4654- // Compute expected differential/pullback type.
4655- Type expectedFuncEltType =
4656- originalFnType->getAutoDiffDerivativeFunctionLinearMapType (
4657- resolvedDiffParamIndices, kind.getLinearMapKind (), lookupConformance,
4658- /* makeSelfParamFirst*/ true );
4659- if (expectedFuncEltType->hasTypeParameter ())
4660- expectedFuncEltType = derivative->mapTypeIntoContext (expectedFuncEltType);
4661- if (expectedFuncEltType->hasArchetype ())
4662- expectedFuncEltType = expectedFuncEltType->mapTypeOutOfContext ();
4663-
46644656 // Check if differential/pullback type matches expected type.
4665- if (!actualFuncEltType ->isEqual (expectedFuncEltType )) {
4657+ if (!actualLinearMapType ->isEqual (expectedLinearMapType )) {
46664658 // Emit differential/pullback type mismatch error on attribute.
46674659 diags.diagnose (attr->getLocation (),
46684660 diag::derivative_attr_result_func_type_mismatch,
@@ -4675,7 +4667,7 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
46754667 diags
46764668 .diagnose (funcEltTypeRepr->getStartLoc (),
46774669 diag::derivative_attr_result_func_type_mismatch_note,
4678- funcResultElt.getName (), expectedFuncEltType )
4670+ funcResultElt.getName (), expectedLinearMapType )
46794671 .highlight (funcEltTypeRepr->getSourceRange ());
46804672 // Emit note showing original function location, if possible.
46814673 if (originalAFD->getLoc ().isValid ())
0 commit comments