|
43 | 43 | import org.elasticsearch.common.util.concurrent.EsExecutors; |
44 | 44 | import org.elasticsearch.common.util.concurrent.ThreadContext; |
45 | 45 | import org.elasticsearch.common.util.set.Sets; |
| 46 | +import org.elasticsearch.core.Nullable; |
| 47 | +import org.elasticsearch.xcontent.NamedXContentRegistry; |
| 48 | +import org.elasticsearch.xcontent.XContentBuilder; |
46 | 49 | import org.elasticsearch.env.Environment; |
47 | 50 | import org.elasticsearch.env.NodeEnvironment; |
48 | 51 | import org.elasticsearch.http.HttpServerTransport; |
@@ -567,7 +570,7 @@ Collection<Object> createComponents(Client client, ThreadPool threadPool, Cluste |
567 | 570 | extensionComponents |
568 | 571 | ); |
569 | 572 | if (providers != null && providers.isEmpty() == false) { |
570 | | - customRoleProviders.put(extension.toString(), providers); |
| 573 | + customRoleProviders.put(extension.extensionName(), providers); |
571 | 574 | } |
572 | 575 | } |
573 | 576 |
|
@@ -676,37 +679,15 @@ auditTrailService, failureHandler, threadPool, anonymousUser, getAuthorizationEn |
676 | 679 | } |
677 | 680 |
|
678 | 681 | private AuthorizationEngine getAuthorizationEngine() { |
679 | | - AuthorizationEngine authorizationEngine = null; |
680 | | - String extensionName = null; |
681 | | - for (SecurityExtension extension : securityExtensions) { |
682 | | - final AuthorizationEngine extensionEngine = extension.getAuthorizationEngine(settings); |
683 | | - if (extensionEngine != null && authorizationEngine != null) { |
684 | | - throw new IllegalStateException("Extensions [" + extensionName + "] and [" + extension.toString() + "] " |
685 | | - + "both set an authorization engine"); |
686 | | - } |
687 | | - authorizationEngine = extensionEngine; |
688 | | - extensionName = extension.toString(); |
689 | | - } |
690 | | - |
691 | | - if (authorizationEngine != null) { |
692 | | - logger.debug("Using authorization engine from extension [" + extensionName + "]"); |
693 | | - } |
694 | | - return authorizationEngine; |
| 682 | + return findValueFromExtensions("authorization engine", extension -> extension.getAuthorizationEngine(settings)); |
695 | 683 | } |
696 | 684 |
|
697 | 685 | private AuthenticationFailureHandler createAuthenticationFailureHandler(final Realms realms, |
698 | 686 | final SecurityExtension.SecurityComponents components) { |
699 | | - AuthenticationFailureHandler failureHandler = null; |
700 | | - String extensionName = null; |
701 | | - for (SecurityExtension extension : securityExtensions) { |
702 | | - AuthenticationFailureHandler extensionFailureHandler = extension.getAuthenticationFailureHandler(components); |
703 | | - if (extensionFailureHandler != null && failureHandler != null) { |
704 | | - throw new IllegalStateException("Extensions [" + extensionName + "] and [" + extension.toString() + "] " |
705 | | - + "both set an authentication failure handler"); |
706 | | - } |
707 | | - failureHandler = extensionFailureHandler; |
708 | | - extensionName = extension.toString(); |
709 | | - } |
| 687 | + AuthenticationFailureHandler failureHandler = findValueFromExtensions( |
| 688 | + "authentication failure handler", |
| 689 | + extension -> extension.getAuthenticationFailureHandler(components) |
| 690 | + ); |
710 | 691 | if (failureHandler == null) { |
711 | 692 | logger.debug("Using default authentication failure handler"); |
712 | 693 | Supplier<Map<String, List<String>>> headersSupplier = () -> { |
@@ -743,12 +724,48 @@ private AuthenticationFailureHandler createAuthenticationFailureHandler(final Re |
743 | 724 | getLicenseState().addListener(() -> { |
744 | 725 | finalDefaultFailureHandler.setHeaders(headersSupplier.get()); |
745 | 726 | }); |
746 | | - } else { |
747 | | - logger.debug("Using authentication failure handler from extension [" + extensionName + "]"); |
748 | 727 | } |
749 | 728 | return failureHandler; |
750 | 729 | } |
751 | 730 |
|
| 731 | + /** |
| 732 | + * Calls the provided function for each configured extension and return the value that was generated by the extensions. |
| 733 | + * If multiple extensions provide a value, throws {@link IllegalStateException}. |
| 734 | + * If no extensions provide a value (or if there are no extensions) returns {@code null}. |
| 735 | + */ |
| 736 | + @Nullable |
| 737 | + private <T> T findValueFromExtensions(String valueType, Function<SecurityExtension, T> method) { |
| 738 | + T foundValue = null; |
| 739 | + String fromExtension = null; |
| 740 | + for (SecurityExtension extension : securityExtensions) { |
| 741 | + final T extensionValue = method.apply(extension); |
| 742 | + if (extensionValue == null) { |
| 743 | + continue; |
| 744 | + } |
| 745 | + if (foundValue == null) { |
| 746 | + foundValue = extensionValue; |
| 747 | + fromExtension = extension.extensionName(); |
| 748 | + } else { |
| 749 | + throw new IllegalStateException( |
| 750 | + "Extensions [" |
| 751 | + + fromExtension |
| 752 | + + "] and [" |
| 753 | + + extension.extensionName() |
| 754 | + + "] " |
| 755 | + + " both attempted to provide a value for [" |
| 756 | + + valueType |
| 757 | + + "]" |
| 758 | + ); |
| 759 | + } |
| 760 | + } |
| 761 | + if (foundValue == null) { |
| 762 | + return null; |
| 763 | + } else { |
| 764 | + logger.debug("Using [{}] [{}] from extension [{}]", valueType, foundValue, fromExtension); |
| 765 | + return foundValue; |
| 766 | + } |
| 767 | + } |
| 768 | + |
752 | 769 | @Override |
753 | 770 | public Settings additionalSettings() { |
754 | 771 | return additionalSettings(settings, enabled); |
|
0 commit comments