diff --git a/spring-core/src/main/java/org/springframework/core/convert/support/GenericConversionService.java b/spring-core/src/main/java/org/springframework/core/convert/support/GenericConversionService.java index 525b9f250d91..77f4b9135858 100644 --- a/spring-core/src/main/java/org/springframework/core/convert/support/GenericConversionService.java +++ b/spring-core/src/main/java/org/springframework/core/convert/support/GenericConversionService.java @@ -17,9 +17,10 @@ package org.springframework.core.convert.support; import java.lang.reflect.Array; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Collections; -import java.util.HashSet; +import java.util.Deque; import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.LinkedList; @@ -518,8 +519,8 @@ public void remove(Class sourceType, Class targetType) { */ public GenericConverter find(TypeDescriptor sourceType, TypeDescriptor targetType) { // Search the full type hierarchy - List> sourceCandidates = getClassHierarchy(sourceType.getType()); - List> targetCandidates = getClassHierarchy(targetType.getType()); + Iterable> sourceCandidates = getClassHierarchy(sourceType.getType()); + Iterable> targetCandidates = getClassHierarchy(targetType.getType()); for (Class sourceCandidate : sourceCandidates) { for (Class targetCandidate : targetCandidates) { ConvertiblePair convertiblePair = new ConvertiblePair(sourceCandidate, targetCandidate); @@ -555,41 +556,49 @@ private GenericConverter getRegisteredConverter(TypeDescriptor sourceType, /** * Returns an ordered class hierarchy for the given type. * @param type the type - * @return an ordered list of all classes that the given type extends or implements + * @return an Iterable of all classes that the given type extends or implements */ - private List> getClassHierarchy(Class type) { - List> hierarchy = new ArrayList>(20); - Set> visited = new HashSet>(20); - addToClassHierarchy(0, ClassUtils.resolvePrimitiveIfNecessary(type), false, hierarchy, visited); + private Iterable> getClassHierarchy(Class type) { + Deque> classStack = new ArrayDeque>(20); + LinkedHashSet> hierarchy = new LinkedHashSet>(20); boolean array = type.isArray(); - int i = 0; - while (i < hierarchy.size()) { - Class candidate = hierarchy.get(i); - candidate = (array ? candidate.getComponentType() : ClassUtils.resolvePrimitiveIfNecessary(candidate)); + + classStack.push(ClassUtils.resolvePrimitiveIfNecessary(type)); + + Class candidate = null; + while((candidate = classStack.pollFirst()) != null) { + candidate = ClassUtils.resolvePrimitiveIfNecessary(candidate); + hierarchy.add(!candidate.isArray() && array ? Array.newInstance(candidate, 0).getClass() : candidate); + Class superclass = candidate.getSuperclass(); - if (candidate.getSuperclass() != null && superclass != Object.class) { - addToClassHierarchy(i + 1, candidate.getSuperclass(), array, hierarchy, visited); + if (superclass != null && superclass != Object.class && superclass != Enum.class) { + classStack.push(superclass); } - for (Class implementedInterface : candidate.getInterfaces()) { - addToClassHierarchy(hierarchy.size(), implementedInterface, array, hierarchy, visited); + + for(Class implementedInterface : candidate.getInterfaces()) { + // add interfaces to the other end of the queue, so that + // concrete classes always come first in the hierarchy + classStack.addLast(implementedInterface); } - i++; } - addToClassHierarchy(hierarchy.size(), Object.class, array, hierarchy, visited); - addToClassHierarchy(hierarchy.size(), Object.class, false, hierarchy, visited); - return hierarchy; - } - - private void addToClassHierarchy(int index, Class type, boolean asArray, - List> hierarchy, Set> visited) { - if (asArray) { - type = Array.newInstance(type, 0).getClass(); + + // make sure Enum comes at the "end" of the hierarchy (if necessary) + if(type.isEnum()) { + if(array) { + hierarchy.add(Array.newInstance(Enum.class, 0).getClass()); + } + hierarchy.add(Enum.class); } - if (visited.add(type)) { - hierarchy.add(index, type); + + // always add Object to hierarchy + if(array) { + hierarchy.add(Array.newInstance(Object.class, 0).getClass()); } + hierarchy.add(Object.class); + + return hierarchy; } - + @Override public String toString() { StringBuilder builder = new StringBuilder(); diff --git a/spring-core/src/test/java/org/springframework/core/convert/support/GenericConversionServiceTests.java b/spring-core/src/test/java/org/springframework/core/convert/support/GenericConversionServiceTests.java index a9f22acff2dc..0d0a257c3bd2 100644 --- a/spring-core/src/test/java/org/springframework/core/convert/support/GenericConversionServiceTests.java +++ b/spring-core/src/test/java/org/springframework/core/convert/support/GenericConversionServiceTests.java @@ -758,6 +758,21 @@ public void testEnumWithInterfaceToStringConversion() { assertEquals("1", result); } + @Test + public void testStringToEnumWithInterfaceConversion() { + conversionService.addConverterFactory(new StringToEnumConverterFactory()); + conversionService.addConverterFactory(new StringToMyEnumInterfaceConverterFactory()); + assertEquals(MyEnum.A, conversionService.convert("1", MyEnum.class)); + } + + @Test + public void testStringToEnumWithBaseInterfaceConversion() { + conversionService.addConverterFactory(new StringToEnumConverterFactory()); + conversionService.addConverterFactory(new StringToMyEnumBaseInterfaceConverterFactory()); + assertEquals(MyEnum.A, conversionService.convert("base1", MyEnum.class)); + } + + @Test public void convertNullAnnotatedStringToString() throws Exception { DefaultConversionService.addDefaultConverters(conversionService); @@ -930,20 +945,35 @@ public int getNestedMatchAttempts() { } } - - interface MyEnumInterface { - + interface MyEnumBaseInterface { + String getBaseCode(); + } + + interface MyEnumInterface extends MyEnumBaseInterface { String getCode(); } public static enum MyEnum implements MyEnumInterface { - A { - @Override - public String getCode() { - return "1"; - } - } + A("1"), + B("2"), + C("3"); + + private String code; + + MyEnum(String code) { + this.code = code; + } + + @Override + public String getCode() { + return code; + } + + @Override + public String getBaseCode() { + return "base" + code; + } } @@ -970,6 +1000,59 @@ public String convert(T source) { return source.getCode(); } } + + private static class StringToMyEnumInterfaceConverterFactory implements ConverterFactory { + + @SuppressWarnings("unchecked") + public Converter getConverter(Class targetType) { + return new StringToMyEnumInterfaceConverter(targetType); + } + + private static class StringToMyEnumInterfaceConverter & MyEnumInterface> implements Converter { + private final Class enumType; + + public StringToMyEnumInterfaceConverter(Class enumType) { + this.enumType = enumType; + } + + public T convert(String source) { + for (T value : enumType.getEnumConstants()) { + if (value.getCode().equals(source)) { + return value; + } + } + return null; + } + } + + } + + private static class StringToMyEnumBaseInterfaceConverterFactory implements ConverterFactory { + + @SuppressWarnings("unchecked") + public Converter getConverter(Class targetType) { + return new StringToMyEnumBaseInterfaceConverter(targetType); + } + + private static class StringToMyEnumBaseInterfaceConverter & MyEnumBaseInterface> implements Converter { + private final Class enumType; + + public StringToMyEnumBaseInterfaceConverter(Class enumType) { + this.enumType = enumType; + } + + public T convert(String source) { + for (T value : enumType.getEnumConstants()) { + if (value.getBaseCode().equals(source)) { + return value; + } + } + return null; + } + } + + } + public static class MyStringToStringCollectionConverter implements Converter> {