diff --git a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/AbstractDependsOnBeanFactoryPostProcessor.java b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/AbstractDependsOnBeanFactoryPostProcessor.java index 60a77897012a..5e3a47031c51 100644 --- a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/AbstractDependsOnBeanFactoryPostProcessor.java +++ b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/AbstractDependsOnBeanFactoryPostProcessor.java @@ -17,8 +17,9 @@ package org.springframework.boot.autoconfigure; import java.util.Arrays; -import java.util.HashSet; +import java.util.LinkedHashSet; import java.util.Set; +import java.util.stream.Collectors; import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.BeanFactoryUtils; @@ -39,6 +40,7 @@ * @author Dave Syer * @author Phillip Webb * @author Andy Wilkinson + * @author Dmytro Nosan * @since 1.3.0 * @see BeanDefinition#setDependsOn(String[]) */ @@ -48,7 +50,7 @@ public abstract class AbstractDependsOnBeanFactoryPostProcessor implements BeanF private final Class> factoryBeanClass; - private final String[] dependsOn; + private final Object[] dependsOn; protected AbstractDependsOnBeanFactoryPostProcessor(Class beanClass, Class> factoryBeanClass, String... dependsOn) { @@ -57,6 +59,20 @@ protected AbstractDependsOnBeanFactoryPostProcessor(Class beanClass, this.dependsOn = dependsOn; } + /** + * Create an instance with target bean and factory bean classes and dependency types. + * @param beanClass target bean class + * @param factoryBeanClass target factory bean class + * @param dependsOn dependency types + * @since 2.2.0 + */ + protected AbstractDependsOnBeanFactoryPostProcessor(Class beanClass, + Class> factoryBeanClass, Class... dependsOn) { + this.beanClass = beanClass; + this.factoryBeanClass = factoryBeanClass; + this.dependsOn = dependsOn; + } + /** * Create an instance with target bean class and dependencies. * @param beanClass target bean class @@ -67,31 +83,51 @@ protected AbstractDependsOnBeanFactoryPostProcessor(Class beanClass, String.. this(beanClass, null, dependsOn); } + /** + * Create an instance with target bean class and dependency types. + * @param beanClass target bean class + * @param dependsOn dependency types + * @since 2.2.0 + */ + protected AbstractDependsOnBeanFactoryPostProcessor(Class beanClass, Class... dependsOn) { + this(beanClass, null, dependsOn); + } + @Override public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) { for (String beanName : getBeanNames(beanFactory)) { BeanDefinition definition = getBeanDefinition(beanName, beanFactory); String[] dependencies = definition.getDependsOn(); - for (String bean : this.dependsOn) { + for (String bean : getDependsOn(beanFactory)) { dependencies = StringUtils.addStringToArray(dependencies, bean); } definition.setDependsOn(dependencies); } } - private Iterable getBeanNames(ListableBeanFactory beanFactory) { - Set names = new HashSet<>(); - names.addAll(Arrays - .asList(BeanFactoryUtils.beanNamesForTypeIncludingAncestors(beanFactory, this.beanClass, true, false))); + private Set getDependsOn(ListableBeanFactory beanFactory) { + if (this.dependsOn instanceof Class[]) { + return Arrays.stream(((Class[]) this.dependsOn)) + .flatMap((beanClass) -> getBeanNames(beanFactory, beanClass).stream()) + .collect(Collectors.toCollection(LinkedHashSet::new)); + } + return Arrays.stream(this.dependsOn).map(String::valueOf).collect(Collectors.toCollection(LinkedHashSet::new)); + } + + private Set getBeanNames(ListableBeanFactory beanFactory) { + Set names = getBeanNames(beanFactory, this.beanClass); if (this.factoryBeanClass != null) { - for (String factoryBeanName : BeanFactoryUtils.beanNamesForTypeIncludingAncestors(beanFactory, - this.factoryBeanClass, true, false)) { - names.add(BeanFactoryUtils.transformedBeanName(factoryBeanName)); - } + names.addAll(getBeanNames(beanFactory, this.factoryBeanClass)); } return names; } + private static Set getBeanNames(ListableBeanFactory beanFactory, Class beanClass) { + String[] names = BeanFactoryUtils.beanNamesForTypeIncludingAncestors(beanFactory, beanClass, true, false); + return Arrays.stream(names).map(BeanFactoryUtils::transformedBeanName) + .collect(Collectors.toCollection(LinkedHashSet::new)); + } + private static BeanDefinition getBeanDefinition(String beanName, ConfigurableListableBeanFactory beanFactory) { try { return beanFactory.getBeanDefinition(beanName); diff --git a/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/AbstractDependsOnBeanFactoryPostProcessorTests.java b/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/AbstractDependsOnBeanFactoryPostProcessorTests.java new file mode 100644 index 000000000000..660029efb5d2 --- /dev/null +++ b/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/AbstractDependsOnBeanFactoryPostProcessorTests.java @@ -0,0 +1,178 @@ +/* + * Copyright 2012-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.autoconfigure; + +import org.junit.jupiter.api.Test; + +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.FactoryBean; +import org.springframework.beans.factory.NoSuchBeanDefinitionException; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.boot.test.context.assertj.AssertableApplicationContext; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.AnnotationConfigApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link AbstractDependsOnBeanFactoryPostProcessor}. + * + * @author Dmytro Nosan + */ +class AbstractDependsOnBeanFactoryPostProcessorTests { + + private ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(FooBarConfiguration.class); + + @Test + void fooBeansShouldDependOnBarBeanNames() { + this.contextRunner + .withUserConfiguration(FooDependsOnBarNamePostProcessor.class, FooBarFactoryBeanConfiguration.class) + .run(this::assertThatFooDependsOnBar); + } + + @Test + void fooBeansShouldDependOnBarBeanTypes() { + this.contextRunner + .withUserConfiguration(FooDependsOnBarTypePostProcessor.class, FooBarFactoryBeanConfiguration.class) + .run(this::assertThatFooDependsOnBar); + } + + @Test + void fooBeansShouldDependOnBarBeanNamesParentContext() { + try (AnnotationConfigApplicationContext parentContext = new AnnotationConfigApplicationContext( + FooBarFactoryBeanConfiguration.class)) { + this.contextRunner.withUserConfiguration(FooDependsOnBarNamePostProcessor.class).withParent(parentContext) + .run(this::assertThatFooDependsOnBar); + } + } + + @Test + void fooBeansShouldDependOnBarBeanTypesParentContext() { + try (AnnotationConfigApplicationContext parentContext = new AnnotationConfigApplicationContext( + FooBarFactoryBeanConfiguration.class)) { + this.contextRunner.withUserConfiguration(FooDependsOnBarTypePostProcessor.class).withParent(parentContext) + .run(this::assertThatFooDependsOnBar); + } + } + + private void assertThatFooDependsOnBar(AssertableApplicationContext context) { + ConfigurableListableBeanFactory beanFactory = context.getBeanFactory(); + assertThat(getBeanDefinition("foo", beanFactory).getDependsOn()).containsExactly("bar", "barFactoryBean"); + assertThat(getBeanDefinition("fooFactoryBean", beanFactory).getDependsOn()).containsExactly("bar", + "barFactoryBean"); + } + + private BeanDefinition getBeanDefinition(String beanName, ConfigurableListableBeanFactory beanFactory) { + try { + return beanFactory.getBeanDefinition(beanName); + } + catch (NoSuchBeanDefinitionException ex) { + BeanFactory parentBeanFactory = beanFactory.getParentBeanFactory(); + if (parentBeanFactory instanceof ConfigurableListableBeanFactory) { + return getBeanDefinition(beanName, (ConfigurableListableBeanFactory) parentBeanFactory); + } + throw ex; + } + } + + static class Foo { + + } + + static class Bar { + + } + + @Configuration(proxyBeanMethods = false) + static class FooBarFactoryBeanConfiguration { + + @Bean + public FooFactoryBean fooFactoryBean() { + return new FooFactoryBean(); + } + + @Bean + public BarFactoryBean barFactoryBean() { + return new BarFactoryBean(); + } + + } + + @Configuration(proxyBeanMethods = false) + static class FooBarConfiguration { + + @Bean + public Bar bar() { + return new Bar(); + } + + @Bean + public Foo foo() { + return new Foo(); + } + + } + + static class FooDependsOnBarTypePostProcessor extends AbstractDependsOnBeanFactoryPostProcessor { + + protected FooDependsOnBarTypePostProcessor() { + super(Foo.class, FooFactoryBean.class, Bar.class, BarFactoryBean.class); + } + + } + + static class FooDependsOnBarNamePostProcessor extends AbstractDependsOnBeanFactoryPostProcessor { + + protected FooDependsOnBarNamePostProcessor() { + super(Foo.class, FooFactoryBean.class, "bar", "barFactoryBean"); + } + + } + + static class FooFactoryBean implements FactoryBean { + + @Override + public Foo getObject() { + return new Foo(); + } + + @Override + public Class getObjectType() { + return Foo.class; + } + + } + + static class BarFactoryBean implements FactoryBean { + + @Override + public Bar getObject() { + return new Bar(); + } + + @Override + public Class getObjectType() { + return Bar.class; + } + + } + +}