diff --git a/compiler/src/org.graalvm.compiler.java/src/org/graalvm/compiler/java/LambdaUtils.java b/compiler/src/org.graalvm.compiler.java/src/org/graalvm/compiler/java/LambdaUtils.java index 1ef6b965124d..5d013f7082c3 100644 --- a/compiler/src/org.graalvm.compiler.java/src/org/graalvm/compiler/java/LambdaUtils.java +++ b/compiler/src/org.graalvm.compiler.java/src/org/graalvm/compiler/java/LambdaUtils.java @@ -51,6 +51,8 @@ public final class LambdaUtils { private static final Pattern LAMBDA_PATTERN = Pattern.compile("\\$\\$Lambda\\$\\d+[/\\.][^/]+;"); private static final char[] HEX = "0123456789abcdef".toCharArray(); + public static final String LAMBDA_SPLIT_PATTERN = "\\$\\$Lambda\\$"; + public static final String LAMBDA_CLASS_NAME_SUBSTRING = "$$Lambda$"; private static GraphBuilderConfiguration buildLambdaParserConfig(ClassInitializationPlugin cip) { GraphBuilderConfiguration.Plugins plugins = new GraphBuilderConfiguration.Plugins(new InvocationPlugins()); @@ -107,7 +109,7 @@ public static String findStableLambdaName(ClassInitializationPlugin cip, Provide public static boolean isLambdaType(ResolvedJavaType type) { String typeName = type.getName(); - return type.isFinalFlagSet() && typeName.contains("/") && typeName.contains("$$Lambda$") && lambdaMatcher(type.getName()).find(); + return type.isFinalFlagSet() && typeName.contains("/") && typeName.contains(LAMBDA_CLASS_NAME_SUBSTRING) && lambdaMatcher(type.getName()).find(); } private static String createStableLambdaName(ResolvedJavaType lambdaType, List targetMethods) { diff --git a/sdk/src/org.graalvm.nativeimage/src/org/graalvm/nativeimage/impl/RuntimeSerializationSupport.java b/sdk/src/org.graalvm.nativeimage/src/org/graalvm/nativeimage/impl/RuntimeSerializationSupport.java index b19a6016459a..ee18235f09be 100644 --- a/sdk/src/org.graalvm.nativeimage/src/org/graalvm/nativeimage/impl/RuntimeSerializationSupport.java +++ b/sdk/src/org.graalvm.nativeimage/src/org/graalvm/nativeimage/impl/RuntimeSerializationSupport.java @@ -50,4 +50,5 @@ public interface RuntimeSerializationSupport { void registerWithTargetConstructorClass(ConfigurationCondition condition, String className, String customTargetConstructorClassName); + void registerLambdaCapturingClass(ConfigurationCondition condition, String lambdaCapturingClassName); } diff --git a/substratevm/src/com.oracle.graal.pointsto/src/com/oracle/graal/pointsto/reports/CallTreePrinter.java b/substratevm/src/com.oracle.graal.pointsto/src/com/oracle/graal/pointsto/reports/CallTreePrinter.java index 854c982d0217..43e0ced1d4b6 100644 --- a/substratevm/src/com.oracle.graal.pointsto/src/com/oracle/graal/pointsto/reports/CallTreePrinter.java +++ b/substratevm/src/com.oracle.graal.pointsto/src/com/oracle/graal/pointsto/reports/CallTreePrinter.java @@ -66,6 +66,7 @@ import jdk.vm.ci.meta.JavaKind; import jdk.vm.ci.meta.ResolvedJavaMethod; import jdk.vm.ci.meta.ResolvedJavaType; +import org.graalvm.compiler.java.LambdaUtils; public final class CallTreePrinter { @@ -311,7 +312,7 @@ public Set classesSet(boolean packageNameOnly) { String name = method.getDeclaringClass().toJavaName(true); if (packageNameOnly) { name = packagePrefix(name); - if (name.contains("$$Lambda$")) { + if (name.contains(LambdaUtils.LAMBDA_CLASS_NAME_SUBSTRING)) { /* Also strip synthetic package names added for lambdas. */ name = packagePrefix(name); } diff --git a/substratevm/src/com.oracle.svm.agent/src/com/oracle/svm/agent/BreakpointInterceptor.java b/substratevm/src/com.oracle.svm.agent/src/com/oracle/svm/agent/BreakpointInterceptor.java index d8040150826c..c4ec77eda343 100644 --- a/substratevm/src/com.oracle.svm.agent/src/com/oracle/svm/agent/BreakpointInterceptor.java +++ b/substratevm/src/com.oracle.svm.agent/src/com/oracle/svm/agent/BreakpointInterceptor.java @@ -146,10 +146,20 @@ final class BreakpointInterceptor { /* Classes from these class loaders are assumed to not be dynamically loaded. */ private static JNIObjectHandle[] builtinClassLoaders; - private static void traceBreakpoint(JNIEnvironment env, JNIObjectHandle clazz, JNIObjectHandle declaringClass, JNIObjectHandle callerClass, String function, Object result, + private static void traceReflectBreakpoint(JNIEnvironment env, JNIObjectHandle clazz, JNIObjectHandle declaringClass, JNIObjectHandle callerClass, String function, Object result, JNIMethodId[] stackTrace, Object... args) { + traceBreakpoint(env, "reflect", clazz, declaringClass, callerClass, function, result, stackTrace, args); + } + + private static void traceSerializeBreakpoint(JNIEnvironment env, String function, Object result, + JNIMethodId[] stackTrace, Object... args) { + traceBreakpoint(env, "serialization", nullHandle(), nullHandle(), nullHandle(), function, result, stackTrace, args); + } + + private static void traceBreakpoint(JNIEnvironment env, String context, JNIObjectHandle clazz, JNIObjectHandle declaringClass, JNIObjectHandle callerClass, String function, Object result, + JNIMethodId[] stackTrace, Object[] args) { if (tracer != null) { - tracer.traceCall("reflect", + tracer.traceCall(context, function, getClassNameOr(env, clazz, null, Tracer.UNKNOWN_VALUE), getClassNameOr(env, declaringClass, null, Tracer.UNKNOWN_VALUE), @@ -204,7 +214,7 @@ private static boolean forName(JNIEnvironment jni, Breakpoint bp, InterceptedSta result = loadedClass.notEqual(nullHandle()); } - traceBreakpoint(jni, bp.clazz, nullHandle(), callerClass, bp.specification.methodName, result, state.getFullStackTraceOrNull(), className); + traceReflectBreakpoint(jni, bp.clazz, nullHandle(), callerClass, bp.specification.methodName, result, state.getFullStackTraceOrNull(), className); return true; } @@ -219,7 +229,7 @@ private static boolean getDeclaredFields(JNIEnvironment jni, Breakpoint bp, Inte private static boolean handleGetFields(JNIEnvironment jni, Breakpoint bp, InterceptedState state) { JNIObjectHandle callerClass = state.getDirectCallerClass(); JNIObjectHandle self = getObjectArgument(0); - traceBreakpoint(jni, getClassOrSingleProxyInterface(jni, self), nullHandle(), callerClass, bp.specification.methodName, null, state.getFullStackTraceOrNull()); + traceReflectBreakpoint(jni, getClassOrSingleProxyInterface(jni, self), nullHandle(), callerClass, bp.specification.methodName, null, state.getFullStackTraceOrNull()); return true; } @@ -242,7 +252,7 @@ private static boolean getDeclaredConstructors(JNIEnvironment jni, Breakpoint bp private static boolean handleGetMethods(JNIEnvironment jni, Breakpoint bp, InterceptedState state) { JNIObjectHandle callerClass = state.getDirectCallerClass(); JNIObjectHandle self = getObjectArgument(0); - traceBreakpoint(jni, getClassOrSingleProxyInterface(jni, self), nullHandle(), callerClass, bp.specification.methodName, null, state.getFullStackTraceOrNull()); + traceReflectBreakpoint(jni, getClassOrSingleProxyInterface(jni, self), nullHandle(), callerClass, bp.specification.methodName, null, state.getFullStackTraceOrNull()); return true; } @@ -261,7 +271,7 @@ private static boolean getPermittedSubclasses(JNIEnvironment jni, Breakpoint bp, private static boolean handleGetClasses(JNIEnvironment jni, Breakpoint bp, InterceptedState state) { JNIObjectHandle callerClass = state.getDirectCallerClass(); JNIObjectHandle self = getObjectArgument(0); - traceBreakpoint(jni, self, nullHandle(), callerClass, bp.specification.methodName, null, state.getFullStackTraceOrNull()); + traceReflectBreakpoint(jni, self, nullHandle(), callerClass, bp.specification.methodName, null, state.getFullStackTraceOrNull()); return true; } @@ -288,7 +298,7 @@ private static boolean handleGetField(JNIEnvironment jni, Breakpoint bp, boolean declaring = nullHandle(); } } - traceBreakpoint(jni, getClassOrSingleProxyInterface(jni, self), getClassOrSingleProxyInterface(jni, declaring), callerClass, bp.specification.methodName, result.notEqual(nullHandle()), + traceReflectBreakpoint(jni, getClassOrSingleProxyInterface(jni, self), getClassOrSingleProxyInterface(jni, declaring), callerClass, bp.specification.methodName, result.notEqual(nullHandle()), state.getFullStackTraceOrNull(), fromJniString(jni, name)); return true; } @@ -302,7 +312,7 @@ private static boolean objectFieldOffsetByName(JNIEnvironment jni, Breakpoint bp boolean validResult = !clearException(jni); JNIObjectHandle clazz = getMethodDeclaringClass(bp.method); - traceBreakpoint(jni, clazz, declaring, callerClass, "objectFieldOffset", validResult, state.getFullStackTraceOrNull(), fromJniString(jni, name)); + traceReflectBreakpoint(jni, clazz, declaring, callerClass, "objectFieldOffset", validResult, state.getFullStackTraceOrNull(), fromJniString(jni, name)); return true; } @@ -315,7 +325,7 @@ private static boolean getConstructor(JNIEnvironment jni, Breakpoint bp, Interce result = nullHandle(); } Object paramTypes = getClassArrayNames(jni, paramTypesHandle); - traceBreakpoint(jni, getClassOrSingleProxyInterface(jni, self), nullHandle(), callerClass, bp.specification.methodName, nullHandle().notEqual(result), state.getFullStackTraceOrNull(), + traceReflectBreakpoint(jni, getClassOrSingleProxyInterface(jni, self), nullHandle(), callerClass, bp.specification.methodName, nullHandle().notEqual(result), state.getFullStackTraceOrNull(), paramTypes); return true; } @@ -346,7 +356,7 @@ private static boolean handleGetMethod(JNIEnvironment jni, Breakpoint bp, boolea } String name = fromJniString(jni, nameHandle); Object paramTypes = getClassArrayNames(jni, paramTypesHandle); - traceBreakpoint(jni, getClassOrSingleProxyInterface(jni, self), getClassOrSingleProxyInterface(jni, declaring), callerClass, bp.specification.methodName, result.notEqual(nullHandle()), + traceReflectBreakpoint(jni, getClassOrSingleProxyInterface(jni, self), getClassOrSingleProxyInterface(jni, declaring), callerClass, bp.specification.methodName, result.notEqual(nullHandle()), state.getFullStackTraceOrNull(), name, paramTypes); return true; } @@ -379,7 +389,7 @@ private static boolean getEnclosingMethod(JNIEnvironment jni, Breakpoint bp, Int } } } - traceBreakpoint(jni, nullHandle(), nullHandle(), callerClass, bp.specification.methodName, enclosing.notEqual(nullHandle()) ? result : false, state.getFullStackTraceOrNull()); + traceReflectBreakpoint(jni, nullHandle(), nullHandle(), callerClass, bp.specification.methodName, enclosing.notEqual(nullHandle()) ? result : false, state.getFullStackTraceOrNull()); return true; } @@ -411,10 +421,8 @@ private static boolean handleInvokeMethod(JNIEnvironment jni, @SuppressWarnings( paramTypesHandle = nullHandle(); } Object paramTypes = getClassArrayNames(jni, paramTypesHandle); - - traceBreakpoint(jni, getClassOrSingleProxyInterface(jni, declaring), getClassOrSingleProxyInterface(jni, declaring), callerClass, "invokeMethod", declaring.notEqual(nullHandle()), + traceReflectBreakpoint(jni, getClassOrSingleProxyInterface(jni, declaring), getClassOrSingleProxyInterface(jni, declaring), callerClass, "invokeMethod", declaring.notEqual(nullHandle()), state.getFullStackTraceOrNull(), name, paramTypes); - /* * Calling Class.newInstance through Method.invoke should register the class for reflective * instantiation @@ -422,7 +430,7 @@ private static boolean handleInvokeMethod(JNIEnvironment jni, @SuppressWarnings( if (isInvoke && isClassNewInstance(jni, declaring, name)) { JNIObjectHandle clazz = getObjectArgument(1); JNIMethodId result = newInstanceMethodID(jni, clazz); - traceBreakpoint(jni, clazz, nullHandle(), callerClass, "newInstance", result.notEqual(nullHandle()), state.getFullStackTraceOrNull()); + traceReflectBreakpoint(jni, clazz, nullHandle(), callerClass, "newInstance", result.notEqual(nullHandle()), state.getFullStackTraceOrNull()); } return true; } @@ -460,8 +468,7 @@ private static boolean handleInvokeConstructor(JNIEnvironment jni, @SuppressWarn paramTypesHandle = nullHandle(); } Object paramTypes = getClassArrayNames(jni, paramTypesHandle); - - traceBreakpoint(jni, getClassOrSingleProxyInterface(jni, declaring), getClassOrSingleProxyInterface(jni, declaring), callerClass, "invokeConstructor", declaring.notEqual(nullHandle()), + traceReflectBreakpoint(jni, getClassOrSingleProxyInterface(jni, declaring), getClassOrSingleProxyInterface(jni, declaring), callerClass, "invokeConstructor", declaring.notEqual(nullHandle()), state.getFullStackTraceOrNull(), paramTypes); return true; } @@ -470,7 +477,7 @@ private static boolean newInstance(JNIEnvironment jni, Breakpoint bp, Intercepte JNIObjectHandle callerClass = state.getDirectCallerClass(); JNIObjectHandle self = getObjectArgument(0); JNIMethodId result = newInstanceMethodID(jni, self); - traceBreakpoint(jni, self, nullHandle(), callerClass, bp.specification.methodName, result.notEqual(nullHandle()), state.getFullStackTraceOrNull()); + traceReflectBreakpoint(jni, self, nullHandle(), callerClass, bp.specification.methodName, result.notEqual(nullHandle()), state.getFullStackTraceOrNull()); return true; } @@ -529,7 +536,7 @@ private static boolean newArrayInstance0(JNIEnvironment jni, Breakpoint bp, JNIV } } String resultClassName = getClassNameOr(jni, resultClass, null, Tracer.UNKNOWN_VALUE); - traceBreakpoint(jni, bp.clazz, nullHandle(), callerClass, bp.specification.methodName, result.notEqual(nullHandle()), state.getFullStackTraceOrNull(), resultClassName); + traceReflectBreakpoint(jni, bp.clazz, nullHandle(), callerClass, bp.specification.methodName, result.notEqual(nullHandle()), state.getFullStackTraceOrNull(), resultClassName); return true; } @@ -561,7 +568,7 @@ private static boolean handleGetResources(JNIEnvironment jni, Breakpoint bp, boo selfClazz = nullHandle(); } } - traceBreakpoint(jni, selfClazz, nullHandle(), callerClass, bp.specification.methodName, result, state.getFullStackTraceOrNull(), fromJniString(jni, name)); + traceReflectBreakpoint(jni, selfClazz, nullHandle(), callerClass, bp.specification.methodName, result, state.getFullStackTraceOrNull(), fromJniString(jni, name)); return true; } @@ -592,7 +599,7 @@ private static boolean handleGetSystemResources(JNIEnvironment jni, Breakpoint b if (result && returnsEnumeration) { result = hasEnumerationElements(jni, returnValue); } - traceBreakpoint(jni, nullHandle(), nullHandle(), callerClass, bp.specification.methodName, result, state.getFullStackTraceOrNull(), fromJniString(jni, name)); + traceReflectBreakpoint(jni, nullHandle(), nullHandle(), callerClass, bp.specification.methodName, result, state.getFullStackTraceOrNull(), fromJniString(jni, name)); return true; } @@ -606,7 +613,7 @@ private static boolean newProxyInstance(JNIEnvironment jni, Breakpoint bp, Inter if (clearException(jni)) { result = false; } - traceBreakpoint(jni, nullHandle(), nullHandle(), callerClass, bp.specification.methodName, result, state.getFullStackTraceOrNull(), Tracer.UNKNOWN_VALUE, ifaceNames, + traceReflectBreakpoint(jni, nullHandle(), nullHandle(), callerClass, bp.specification.methodName, result, state.getFullStackTraceOrNull(), Tracer.UNKNOWN_VALUE, ifaceNames, Tracer.UNKNOWN_VALUE); return true; } @@ -620,7 +627,7 @@ private static boolean getProxyClass(JNIEnvironment jni, Breakpoint bp, Intercep if (clearException(jni)) { result = false; } - traceBreakpoint(jni, nullHandle(), nullHandle(), callerClass, bp.specification.methodName, result, state.getFullStackTraceOrNull(), Tracer.UNKNOWN_VALUE, ifaceNames); + traceReflectBreakpoint(jni, nullHandle(), nullHandle(), callerClass, bp.specification.methodName, result, state.getFullStackTraceOrNull(), Tracer.UNKNOWN_VALUE, ifaceNames); return true; } @@ -659,7 +666,7 @@ private static boolean getBundleImplJDK8OrEarlier(JNIEnvironment jni, Breakpoint } else { bundleInfo = extractBundleInfo(jni, result); } - traceBreakpoint(jni, nullHandle(), nullHandle(), callerClass, "getBundleImplJDK8OrEarlier", result.notEqual(nullHandle()), + traceReflectBreakpoint(jni, nullHandle(), nullHandle(), callerClass, "getBundleImplJDK8OrEarlier", result.notEqual(nullHandle()), state.getFullStackTraceOrNull(), fromJniString(jni, baseName), Tracer.UNKNOWN_VALUE, Tracer.UNKNOWN_VALUE, Tracer.UNKNOWN_VALUE, bundleInfo.classNames, bundleInfo.locales); return true; } @@ -686,7 +693,7 @@ private static boolean getBundleImplJDK11OrLater(JNIEnvironment jni, Breakpoint } else { bundleInfo = extractBundleInfo(jni, result); } - traceBreakpoint(jni, nullHandle(), nullHandle(), callerClass, "getBundleImplJDK11OrLater", result.notEqual(nullHandle()), + traceReflectBreakpoint(jni, nullHandle(), nullHandle(), callerClass, "getBundleImplJDK11OrLater", result.notEqual(nullHandle()), state.getFullStackTraceOrNull(), Tracer.UNKNOWN_VALUE, Tracer.UNKNOWN_VALUE, fromJniString(jni, baseName), Tracer.UNKNOWN_VALUE, Tracer.UNKNOWN_VALUE, bundleInfo.classNames, bundleInfo.locales); return true; @@ -786,7 +793,7 @@ private static boolean loadClass(JNIEnvironment jni, Breakpoint bp, InterceptedS if (clearException(jni)) { clazz = nullHandle(); } - traceBreakpoint(jni, bp.clazz, nullHandle(), callerClass, bp.specification.methodName, clazz.notEqual(nullHandle()), state.getFullStackTraceOrNull(), className); + traceReflectBreakpoint(jni, bp.clazz, nullHandle(), callerClass, bp.specification.methodName, clazz.notEqual(nullHandle()), state.getFullStackTraceOrNull(), className); return true; } @@ -887,7 +894,7 @@ private static boolean methodMethodHandle(JNIEnvironment jni, JNIObjectHandle de JNIObjectHandle result, JNIMethodId[] stackTrace) { String name = fromJniString(jni, nameHandle); Object paramTypes = getClassArrayNames(jni, paramTypesHandle); - traceBreakpoint(jni, declaringClass, nullHandle(), callerClass, "findMethodHandle", result.notEqual(nullHandle()), stackTrace, name, paramTypes); + traceReflectBreakpoint(jni, declaringClass, nullHandle(), callerClass, "findMethodHandle", result.notEqual(nullHandle()), stackTrace, name, paramTypes); return true; } @@ -901,7 +908,7 @@ private static boolean findConstructorHandle(JNIEnvironment jni, Breakpoint bp, result = shouldIncludeMethod(jni, result, agent.handles().javaLangIllegalAccessException); Object paramTypes = getClassArrayNames(jni, getParamTypes(jni, methodType)); - traceBreakpoint(jni, declaringClass, nullHandle(), callerClass, "findConstructorHandle", result.notEqual(nullHandle()), state.getFullStackTraceOrNull(), paramTypes); + traceReflectBreakpoint(jni, declaringClass, nullHandle(), callerClass, "findConstructorHandle", result.notEqual(nullHandle()), state.getFullStackTraceOrNull(), paramTypes); return true; } @@ -924,7 +931,7 @@ private static boolean findFieldHandle(JNIEnvironment jni, Breakpoint bp, Interc result = shouldIncludeMethod(jni, result, agent.handles().javaLangIllegalAccessException); String name = fromJniString(jni, fieldName); - traceBreakpoint(jni, declaringClass, nullHandle(), callerClass, "findFieldHandle", result.notEqual(nullHandle()), state.getFullStackTraceOrNull(), name); + traceReflectBreakpoint(jni, declaringClass, nullHandle(), callerClass, "findFieldHandle", result.notEqual(nullHandle()), state.getFullStackTraceOrNull(), name); return true; } @@ -937,7 +944,7 @@ private static boolean findClass(JNIEnvironment jni, Breakpoint bp, InterceptedS result = shouldIncludeMethod(jni, result, agent.handles().javaLangIllegalAccessException); String name = fromJniString(jni, className); - traceBreakpoint(jni, bp.clazz, nullHandle(), callerClass, "findClass", result.notEqual(nullHandle()), state.getFullStackTraceOrNull(), name); + traceReflectBreakpoint(jni, bp.clazz, nullHandle(), callerClass, "findClass", result.notEqual(nullHandle()), state.getFullStackTraceOrNull(), name); return true; } @@ -960,7 +967,7 @@ private static boolean unreflectField(JNIEnvironment jni, Breakpoint bp, Interce } String fieldName = fromJniString(jni, fieldNameHandle); - traceBreakpoint(jni, declaringClass, nullHandle(), callerClass, "unreflectField", result.notEqual(nullHandle()), state.getFullStackTraceOrNull(), fieldName); + traceReflectBreakpoint(jni, declaringClass, nullHandle(), callerClass, "unreflectField", result.notEqual(nullHandle()), state.getFullStackTraceOrNull(), fieldName); return true; } @@ -977,9 +984,9 @@ private static boolean asInterfaceInstance(JNIEnvironment jni, Breakpoint bp, In intfcNameHandle = nullHandle(); } String intfcName = fromJniString(jni, intfcNameHandle); - traceBreakpoint(jni, intfc, nullHandle(), callerClass, "asInterfaceInstance", result.notEqual(nullHandle()), state.getFullStackTraceOrNull()); + traceReflectBreakpoint(jni, intfc, nullHandle(), callerClass, "asInterfaceInstance", result.notEqual(nullHandle()), state.getFullStackTraceOrNull()); String[] intfcNames = new String[]{intfcName}; - traceBreakpoint(jni, nullHandle(), nullHandle(), callerClass, "newMethodHandleProxyInstance", result.notEqual(nullHandle()), state.getFullStackTraceOrNull(), (Object) intfcNames); + traceReflectBreakpoint(jni, nullHandle(), nullHandle(), callerClass, "newMethodHandleProxyInstance", result.notEqual(nullHandle()), state.getFullStackTraceOrNull(), (Object) intfcNames); return true; } @@ -999,7 +1006,7 @@ private static boolean constantBootstrapGetStaticFinal(JNIEnvironment jni, Break result = shouldIncludeMethod(jni, result, agent.handles().javaLangIllegalAccessError); String name = fromJniString(jni, fieldName); - traceBreakpoint(jni, declaringClass, nullHandle(), callerClass, "findFieldHandle", result.notEqual(nullHandle()), state.getFullStackTraceOrNull(), name); + traceReflectBreakpoint(jni, declaringClass, nullHandle(), callerClass, "findFieldHandle", result.notEqual(nullHandle()), state.getFullStackTraceOrNull(), name); return true; } @@ -1034,7 +1041,7 @@ private static boolean methodTypeFromDescriptor(JNIEnvironment jni, Breakpoint b } } - traceBreakpoint(jni, nullHandle(), nullHandle(), callerClass, "methodTypeDescriptor", result.notEqual(nullHandle()), state.getFullStackTraceOrNull(), types); + traceReflectBreakpoint(jni, nullHandle(), nullHandle(), callerClass, "methodTypeDescriptor", result.notEqual(nullHandle()), state.getFullStackTraceOrNull(), types); return true; } @@ -1056,67 +1063,74 @@ private static JNIObjectHandle shouldIncludeMethod(JNIEnvironment jni, JNIObject return result; } + /** + * We have to find a class that captures a lambda function so it can be registered by the agent. + * We have to get a SerializedLambda instance first. After that we get a lambda capturing class + * from that instance using JNIHandleSet#getFieldId to get field id and JNIObjectHandle#invoke + * on to get that field value. We get a name of the capturing class and tell the agent to + * register it. + */ + private static boolean serializedLambdaReadResolve(JNIEnvironment jni, @SuppressWarnings("unused") Breakpoint bp, InterceptedState state) { + JNIObjectHandle serializedLambdaInstance = getObjectArgument(0); + JNIObjectHandle capturingClass = jniFunctions().getGetObjectField().invoke(jni, serializedLambdaInstance, + agent.handles().javaLangInvokeSerializedLambdaCapturingClass); + + String capturingClassName = getClassNameOrNull(jni, capturingClass); + boolean validCapturingClass = nullHandle().notEqual(capturingClass); + + traceSerializeBreakpoint(jni, "SerializedLambda.readResolve", validCapturingClass, state.getFullStackTraceOrNull(), capturingClassName); + return true; + } + private static boolean objectStreamClassConstructor(JNIEnvironment jni, Breakpoint bp, InterceptedState state) { JNIObjectHandle serializeTargetClass = getObjectArgument(1); - String serializeTargetClassName = getClassNameOrNull(jni, serializeTargetClass); + if (Support.isSerializable(jni, serializeTargetClass)) { + String serializeTargetClassName = getClassNameOrNull(jni, serializeTargetClass); - JNIObjectHandle objectStreamClassInstance = Support.newObjectL(jni, bp.clazz, bp.method, serializeTargetClass); - boolean validObjectStreamClassInstance = nullHandle().notEqual(objectStreamClassInstance); - if (clearException(jni)) { - validObjectStreamClassInstance = false; - } - - // Skip Lambda class serialization - if (serializeTargetClassName.contains("$$Lambda$")) { - return true; - } + JNIObjectHandle objectStreamClassInstance = Support.newObjectL(jni, bp.clazz, bp.method, serializeTargetClass); + boolean validObjectStreamClassInstance = nullHandle().notEqual(objectStreamClassInstance); + if (clearException(jni)) { + validObjectStreamClassInstance = false; + } - List transitiveSerializeTargets = new ArrayList<>(); - transitiveSerializeTargets.add(serializeTargetClassName); + List transitiveSerializeTargets = new ArrayList<>(); + transitiveSerializeTargets.add(serializeTargetClassName); - /* - * When the ObjectStreamClass instance is created for the given serializeTargetClass, some - * additional ObjectStreamClass instances (usually the super classes) are created - * recursively. Call ObjectStreamClass.getClassDataLayout0() can get all of them. - */ - JNIMethodId getClassDataLayout0MId = agent.handles().getJavaIoObjectStreamClassGetClassDataLayout0(jni, bp.clazz); - JNIObjectHandle dataLayoutArray = Support.callObjectMethod(jni, objectStreamClassInstance, getClassDataLayout0MId); - if (!clearException(jni) && nullHandle().notEqual(dataLayoutArray)) { - int length = jniFunctions().getGetArrayLength().invoke(jni, dataLayoutArray); - // If only 1 element is got from getClassDataLayout0(). it is base ObjectStreamClass - // instance itself. - if (!clearException(jni) && length > 1) { - JNIFieldId hasDataFId = agent.handles().getJavaIOObjectStreamClassClassDataSlotHasData(jni); - JNIFieldId descFId = agent.handles().getJavaIOObjectStreamClassClassDataSlotDesc(jni); - JNIMethodId javaIoObjectStreamClassForClassMId = agent.handles().getJavaIoObjectStreamClassForClass(jni, bp.clazz); - for (int i = 0; i < length; i++) { - JNIObjectHandle classDataSlot = jniFunctions().getGetObjectArrayElement().invoke(jni, dataLayoutArray, i); - boolean hasData = jniFunctions().getGetBooleanField().invoke(jni, classDataSlot, hasDataFId); - if (hasData) { - JNIObjectHandle oscInstanceInSlot = jniFunctions().getGetObjectField().invoke(jni, classDataSlot, descFId); - if (!jniFunctions().getIsSameObject().invoke(jni, oscInstanceInSlot, objectStreamClassInstance)) { - JNIObjectHandle oscClazz = Support.callObjectMethod(jni, oscInstanceInSlot, javaIoObjectStreamClassForClassMId); - String oscClassName = getClassNameOrNull(jni, oscClazz); - transitiveSerializeTargets.add(oscClassName); + /* + * When the ObjectStreamClass instance is created for the given serializeTargetClass, + * some additional ObjectStreamClass instances (usually the super classes) are created + * recursively. Call ObjectStreamClass.getClassDataLayout0() can get all of them. + */ + JNIMethodId getClassDataLayout0MId = agent.handles().getJavaIoObjectStreamClassGetClassDataLayout0(jni, bp.clazz); + JNIObjectHandle dataLayoutArray = Support.callObjectMethod(jni, objectStreamClassInstance, getClassDataLayout0MId); + + if (!clearException(jni) && nullHandle().notEqual(dataLayoutArray)) { + int length = jniFunctions().getGetArrayLength().invoke(jni, dataLayoutArray); + // If only 1 element is got from getClassDataLayout0(). it is base ObjectStreamClass + // instance itself. + if (!clearException(jni) && length > 1) { + JNIFieldId hasDataFId = agent.handles().getJavaIOObjectStreamClassClassDataSlotHasData(jni); + JNIFieldId descFId = agent.handles().getJavaIOObjectStreamClassClassDataSlotDesc(jni); + JNIMethodId javaIoObjectStreamClassForClassMId = agent.handles().getJavaIoObjectStreamClassForClass(jni, bp.clazz); + for (int i = 0; i < length; i++) { + JNIObjectHandle classDataSlot = jniFunctions().getGetObjectArrayElement().invoke(jni, dataLayoutArray, i); + boolean hasData = jniFunctions().getGetBooleanField().invoke(jni, classDataSlot, hasDataFId); + if (hasData) { + JNIObjectHandle oscInstanceInSlot = jniFunctions().getGetObjectField().invoke(jni, classDataSlot, descFId); + if (!jniFunctions().getIsSameObject().invoke(jni, oscInstanceInSlot, objectStreamClassInstance)) { + JNIObjectHandle oscClazz = Support.callObjectMethod(jni, oscInstanceInSlot, javaIoObjectStreamClassForClassMId); + if (Support.isSerializable(jni, oscClazz)) { + String oscClassName = getClassNameOrNull(jni, oscClazz); + transitiveSerializeTargets.add(oscClassName); + } + } } } } } - } - for (String className : transitiveSerializeTargets) { - if (tracer != null) { - tracer.traceCall("serialization", - "ObjectStreamClass.", - null, - null, - null, - validObjectStreamClassInstance, - state.getFullStackTraceOrNull(), - /*- String serializationTargetClass, String customTargetConstructorClass */ - className, null); - - guarantee(!testException(jni)); + for (String className : transitiveSerializeTargets) { + traceSerializeBreakpoint(jni, "ObjectStreamClass.", validObjectStreamClassInstance, state.getFullStackTraceOrNull(), className, null); } } return true; @@ -1132,31 +1146,15 @@ private static boolean objectStreamClassConstructor(JNIEnvironment jni, Breakpoi */ private static boolean customTargetConstructorSerialization(JNIEnvironment jni, @SuppressWarnings("unused") Breakpoint bp, InterceptedState state) { JNIObjectHandle serializeTargetClass = getObjectArgument(1); - String serializeTargetClassName = getClassNameOrNull(jni, serializeTargetClass); + if (Support.isSerializable(jni, serializeTargetClass)) { + String serializeTargetClassName = getClassNameOrNull(jni, serializeTargetClass); - // Skip Lambda class serialization. - if (serializeTargetClassName.contains("$$Lambda$")) { - return true; - } - - JNIObjectHandle customConstructorObj = getObjectArgument(2); - JNIObjectHandle customConstructorClass = jniFunctions().getGetObjectClass().invoke(jni, customConstructorObj); - JNIMethodId getDeclaringClassNameMethodID = agent.handles().getJavaLangReflectConstructorDeclaringClassName(jni, customConstructorClass); - JNIObjectHandle declaredClassNameObj = Support.callObjectMethod(jni, customConstructorObj, getDeclaringClassNameMethodID); - String customConstructorClassName = fromJniString(jni, declaredClassNameObj); - - if (tracer != null) { - tracer.traceCall("serialization", - "ObjectStreamClass.", - null, - null, - null, - true, - state.getFullStackTraceOrNull(), - /*- String serializationTargetClass, String customTargetConstructorClass */ - serializeTargetClassName, customConstructorClassName); - - guarantee(!testException(jni)); + JNIObjectHandle customConstructorObj = getObjectArgument(2); + JNIObjectHandle customConstructorClass = jniFunctions().getGetObjectClass().invoke(jni, customConstructorObj); + JNIMethodId getDeclaringClassNameMethodID = agent.handles().getJavaLangReflectConstructorDeclaringClassName(jni, customConstructorClass); + JNIObjectHandle declaredClassNameObj = Support.callObjectMethod(jni, customConstructorObj, getDeclaringClassNameMethodID); + String customConstructorClassName = fromJniString(jni, declaredClassNameObj); + traceSerializeBreakpoint(jni, "ObjectStreamClass.", true, state.getFullStackTraceOrNull(), serializeTargetClassName, customConstructorClassName); } return true; } @@ -1493,6 +1491,7 @@ private interface BreakpointHandler { brk("java/lang/reflect/Proxy", "newProxyInstance", "(Ljava/lang/ClassLoader;[Ljava/lang/Class;Ljava/lang/reflect/InvocationHandler;)Ljava/lang/Object;", BreakpointInterceptor::newProxyInstance), + brk("java/lang/invoke/SerializedLambda", "readResolve", "()Ljava/lang/Object;", BreakpointInterceptor::serializedLambdaReadResolve), brk("java/io/ObjectStreamClass", "", "(Ljava/lang/Class;)V", BreakpointInterceptor::objectStreamClassConstructor), brk("jdk/internal/reflect/ReflectionFactory", "newConstructorForSerialization", diff --git a/substratevm/src/com.oracle.svm.agent/src/com/oracle/svm/agent/NativeImageAgentJNIHandleSet.java b/substratevm/src/com.oracle.svm.agent/src/com/oracle/svm/agent/NativeImageAgentJNIHandleSet.java index ee9103f4f690..bebe65427ff8 100644 --- a/substratevm/src/com.oracle.svm.agent/src/com/oracle/svm/agent/NativeImageAgentJNIHandleSet.java +++ b/substratevm/src/com.oracle.svm.agent/src/com/oracle/svm/agent/NativeImageAgentJNIHandleSet.java @@ -81,6 +81,8 @@ public class NativeImageAgentJNIHandleSet extends JNIHandleSet { private JNIFieldId javaUtilResourceBundleParentField; private JNIMethodId javaUtilResourceBundleGetLocale; + final JNIFieldId javaLangInvokeSerializedLambdaCapturingClass; + NativeImageAgentJNIHandleSet(JNIEnvironment env) { super(env); javaLangClass = newClassGlobalRef(env, "java/lang/Class"); @@ -113,6 +115,9 @@ public class NativeImageAgentJNIHandleSet extends JNIHandleSet { javaLangIllegalAccessError = newClassGlobalRef(env, "java/lang/IllegalAccessError"); javaLangInvokeWrongMethodTypeException = newClassGlobalRef(env, "java/lang/invoke/WrongMethodTypeException"); javaLangIllegalArgumentException = newClassGlobalRef(env, "java/lang/IllegalArgumentException"); + + JNIObjectHandle serializedLambda = findClass(env, "java/lang/invoke/SerializedLambda"); + javaLangInvokeSerializedLambdaCapturingClass = getFieldId(env, serializedLambda, "capturingClass", "Ljava/lang/Class;", false); } JNIMethodId getJavaLangReflectExecutableGetParameterTypes(JNIEnvironment env) { diff --git a/substratevm/src/com.oracle.svm.configure/src/com/oracle/svm/configure/config/SerializationConfiguration.java b/substratevm/src/com.oracle.svm.configure/src/com/oracle/svm/configure/config/SerializationConfiguration.java index f6fe9b907841..2cf78bb3cde1 100644 --- a/substratevm/src/com.oracle.svm.configure/src/com/oracle/svm/configure/config/SerializationConfiguration.java +++ b/substratevm/src/com.oracle.svm.configure/src/com/oracle/svm/configure/config/SerializationConfiguration.java @@ -32,6 +32,8 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import com.oracle.svm.configure.json.JsonPrintable; +import org.graalvm.compiler.java.LambdaUtils; import org.graalvm.nativeimage.impl.ConfigurationCondition; import org.graalvm.nativeimage.impl.RuntimeSerializationSupport; @@ -40,36 +42,60 @@ public class SerializationConfiguration implements ConfigurationBase, RuntimeSerializationSupport { - private final Set serializations = ConcurrentHashMap.newKeySet(); + private final Set serializationTypes = ConcurrentHashMap.newKeySet(); + private final Set lambdaSerializationCapturingTypes = ConcurrentHashMap.newKeySet(); public SerializationConfiguration() { } public SerializationConfiguration(SerializationConfiguration other) { - serializations.addAll(other.serializations); + serializationTypes.addAll(other.serializationTypes); + lambdaSerializationCapturingTypes.addAll(other.lambdaSerializationCapturingTypes); } public void removeAll(SerializationConfiguration other) { - serializations.removeAll(other.serializations); + serializationTypes.removeAll(other.serializationTypes); + lambdaSerializationCapturingTypes.removeAll(other.lambdaSerializationCapturingTypes); } public boolean contains(ConfigurationCondition condition, String serializationTargetClass, String customTargetConstructorClass) { - return serializations.contains(createConfigurationType(condition, serializationTargetClass, customTargetConstructorClass)); + return serializationTypes.contains(createConfigurationType(condition, serializationTargetClass, customTargetConstructorClass)) || + lambdaSerializationCapturingTypes.contains(createLambdaCapturingClassConfigurationType(condition, serializationTargetClass)); } @Override public void printJson(JsonWriter writer) throws IOException { - writer.append('[').indent(); + writer.append('{').indent().newline(); + List listOfCapturedClasses = new ArrayList<>(serializationTypes); + Collections.sort(listOfCapturedClasses); + printSerializationClasses(writer, "types", listOfCapturedClasses); + writer.append(",").newline(); + List listOfCapturingClasses = new ArrayList<>(lambdaSerializationCapturingTypes); + listOfCapturingClasses.sort(new SerializationConfigurationLambdaCapturingType.SerializationConfigurationLambdaCapturingTypesComparator()); + printSerializationClasses(writer, "lambdaCapturingTypes", listOfCapturingClasses); + writer.unindent().newline(); + writer.append('}'); + } + + private static void printSerializationClasses(JsonWriter writer, String types, List serializationConfigurationTypes) throws IOException { + writer.quote(types).append(":"); + writer.append('['); + writer.indent(); + + printSerializationTypes(serializationConfigurationTypes, writer); + + writer.unindent().newline(); + writer.append("]"); + } + + private static void printSerializationTypes(List serializationConfigurationTypes, JsonWriter writer) throws IOException { String prefix = ""; - List list = new ArrayList<>(serializations); - Collections.sort(list); - for (SerializationConfigurationType type : list) { + + for (JsonPrintable type : serializationConfigurationTypes) { writer.append(prefix).newline(); type.printJson(writer); prefix = ","; } - writer.unindent().newline(); - writer.append(']'); } @Override @@ -91,12 +117,17 @@ public void registerWithTargetConstructorClass(ConfigurationCondition condition, @Override public void registerWithTargetConstructorClass(ConfigurationCondition condition, String className, String customTargetConstructorClassName) { - serializations.add(createConfigurationType(condition, className, customTargetConstructorClassName)); + serializationTypes.add(createConfigurationType(condition, className, customTargetConstructorClassName)); + } + + @Override + public void registerLambdaCapturingClass(ConfigurationCondition condition, String lambdaCapturingClassName) { + lambdaSerializationCapturingTypes.add(createLambdaCapturingClassConfigurationType(condition, lambdaCapturingClassName.split(LambdaUtils.LAMBDA_SPLIT_PATTERN)[0])); } @Override public boolean isEmpty() { - return serializations.isEmpty(); + return serializationTypes.isEmpty() && lambdaSerializationCapturingTypes.isEmpty(); } private static SerializationConfigurationType createConfigurationType(ConfigurationCondition condition, String className, String customTargetConstructorClassName) { @@ -104,4 +135,9 @@ private static SerializationConfigurationType createConfigurationType(Configurat String convertedCustomTargetConstructorClassName = customTargetConstructorClassName == null ? null : SignatureUtil.toInternalClassName(customTargetConstructorClassName); return new SerializationConfigurationType(condition, convertedClassName, convertedCustomTargetConstructorClassName); } + + private static SerializationConfigurationLambdaCapturingType createLambdaCapturingClassConfigurationType(ConfigurationCondition condition, String className) { + String convertedClassName = SignatureUtil.toInternalClassName(className); + return new SerializationConfigurationLambdaCapturingType(condition, convertedClassName); + } } diff --git a/substratevm/src/com.oracle.svm.configure/src/com/oracle/svm/configure/config/SerializationConfigurationLambdaCapturingType.java b/substratevm/src/com.oracle.svm.configure/src/com/oracle/svm/configure/config/SerializationConfigurationLambdaCapturingType.java new file mode 100644 index 000000000000..670b1484f07a --- /dev/null +++ b/substratevm/src/com.oracle.svm.configure/src/com/oracle/svm/configure/config/SerializationConfigurationLambdaCapturingType.java @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2021, 2021, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package com.oracle.svm.configure.config; + +import java.io.IOException; +import java.util.Comparator; +import java.util.Objects; + +import com.oracle.svm.configure.json.JsonPrintable; +import org.graalvm.nativeimage.impl.ConfigurationCondition; + +import com.oracle.svm.configure.json.JsonWriter; +import com.oracle.svm.core.configure.SerializationConfigurationParser; + +public class SerializationConfigurationLambdaCapturingType implements JsonPrintable { + private final ConfigurationCondition condition; + private final String qualifiedJavaName; + + public SerializationConfigurationLambdaCapturingType(ConfigurationCondition condition, String qualifiedJavaName) { + assert qualifiedJavaName.indexOf('/') == -1 : "Requires qualified Java name, not the internal representation"; + Objects.requireNonNull(condition); + this.condition = condition; + Objects.requireNonNull(qualifiedJavaName); + this.qualifiedJavaName = qualifiedJavaName; + } + + @Override + public void printJson(JsonWriter writer) throws IOException { + writer.append('{').indent().newline(); + ConfigurationConditionPrintable.printConditionAttribute(condition, writer); + + writer.quote(SerializationConfigurationParser.NAME_KEY).append(":").quote(qualifiedJavaName); + writer.unindent().newline().append('}'); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SerializationConfigurationLambdaCapturingType that = (SerializationConfigurationLambdaCapturingType) o; + return condition.equals(that.condition) && + qualifiedJavaName.equals(that.qualifiedJavaName); + } + + @Override + public int hashCode() { + return Objects.hash(condition, qualifiedJavaName); + } + + public static final class SerializationConfigurationLambdaCapturingTypesComparator implements Comparator { + + @Override + public int compare(SerializationConfigurationLambdaCapturingType o1, SerializationConfigurationLambdaCapturingType o2) { + int compareName = o1.qualifiedJavaName.compareTo(o2.qualifiedJavaName); + if (compareName != 0) { + return compareName; + } + return o1.condition.compareTo(o2.condition); + } + } +} diff --git a/substratevm/src/com.oracle.svm.configure/src/com/oracle/svm/configure/config/SerializationConfigurationType.java b/substratevm/src/com.oracle.svm.configure/src/com/oracle/svm/configure/config/SerializationConfigurationType.java index 02de8118fc3b..3e8e6623c034 100644 --- a/substratevm/src/com.oracle.svm.configure/src/com/oracle/svm/configure/config/SerializationConfigurationType.java +++ b/substratevm/src/com.oracle.svm.configure/src/com/oracle/svm/configure/config/SerializationConfigurationType.java @@ -40,9 +40,9 @@ public class SerializationConfigurationType implements JsonPrintable, Comparable private final String qualifiedCustomTargetConstructorJavaName; public SerializationConfigurationType(ConfigurationCondition condition, String qualifiedJavaName, String qualifiedCustomTargetConstructorJavaName) { - assert qualifiedJavaName.indexOf('/') == -1 : "Requires qualified Java name, not internal representation"; + assert qualifiedJavaName.indexOf('/') == -1 : "Requires qualified Java name, not the internal representation"; assert !qualifiedJavaName.startsWith("[") : "Requires Java source array syntax, for example java.lang.String[]"; - assert qualifiedCustomTargetConstructorJavaName == null || qualifiedCustomTargetConstructorJavaName.indexOf('/') == -1 : "Requires qualified Java name, not internal representation"; + assert qualifiedCustomTargetConstructorJavaName == null || qualifiedCustomTargetConstructorJavaName.indexOf('/') == -1 : "Requires qualified Java name, not the internal representation"; assert qualifiedCustomTargetConstructorJavaName == null || !qualifiedCustomTargetConstructorJavaName.startsWith("[") : "Requires Java source array syntax, for example java.lang.String[]"; Objects.requireNonNull(condition); this.condition = condition; diff --git a/substratevm/src/com.oracle.svm.configure/src/com/oracle/svm/configure/trace/SerializationProcessor.java b/substratevm/src/com.oracle.svm.configure/src/com/oracle/svm/configure/trace/SerializationProcessor.java index df3a0df87b0c..82f7e60ed9e8 100644 --- a/substratevm/src/com.oracle.svm.configure/src/com/oracle/svm/configure/trace/SerializationProcessor.java +++ b/substratevm/src/com.oracle.svm.configure/src/com/oracle/svm/configure/trace/SerializationProcessor.java @@ -28,6 +28,7 @@ import java.util.List; import java.util.Map; +import org.graalvm.compiler.java.LambdaUtils; import org.graalvm.nativeimage.impl.ConfigurationCondition; import com.oracle.svm.configure.config.SerializationConfiguration; @@ -54,6 +55,7 @@ void processEntry(Map entry) { } String function = (String) entry.get("function"); List args = (List) entry.get("args"); + if ("ObjectStreamClass.".equals(function)) { expectSize(args, 2); @@ -61,7 +63,21 @@ void processEntry(Map entry) { return; } - serializationConfiguration.registerWithTargetConstructorClass(condition, (String) args.get(0), (String) args.get(1)); + String className = (String) args.get(0); + + if (className.contains(LambdaUtils.LAMBDA_CLASS_NAME_SUBSTRING)) { + serializationConfiguration.registerLambdaCapturingClass(condition, className); + } else { + serializationConfiguration.registerWithTargetConstructorClass(condition, className, (String) args.get(1)); + } + } else if ("SerializedLambda.readResolve".equals(function)) { + expectSize(args, 1); + + if (advisor.shouldIgnore(LazyValueUtils.lazyValue((String) args.get(0)), LazyValueUtils.lazyValue(null))) { + return; + } + + serializationConfiguration.registerLambdaCapturingClass(condition, (String) args.get(0)); } } } diff --git a/substratevm/src/com.oracle.svm.core/src/com/oracle/svm/core/configure/SerializationConfigurationParser.java b/substratevm/src/com.oracle.svm.core/src/com/oracle/svm/core/configure/SerializationConfigurationParser.java index 626b83369576..75e73ae9dd69 100644 --- a/substratevm/src/com.oracle.svm.core/src/com/oracle/svm/core/configure/SerializationConfigurationParser.java +++ b/substratevm/src/com.oracle.svm.core/src/com/oracle/svm/core/configure/SerializationConfigurationParser.java @@ -25,21 +25,24 @@ */ package com.oracle.svm.core.configure; +import com.oracle.svm.core.util.json.JSONParser; +import com.oracle.svm.core.util.json.JSONParserException; +import org.graalvm.nativeimage.impl.ConfigurationCondition; +import org.graalvm.nativeimage.impl.RuntimeSerializationSupport; + import java.io.IOException; import java.io.Reader; import java.util.Arrays; import java.util.Collections; +import java.util.List; import java.util.Map; -import org.graalvm.nativeimage.impl.ConfigurationCondition; -import org.graalvm.nativeimage.impl.RuntimeSerializationSupport; - -import com.oracle.svm.core.util.json.JSONParser; - public class SerializationConfigurationParser extends ConfigurationParser { public static final String NAME_KEY = "name"; public static final String CUSTOM_TARGET_CONSTRUCTOR_CLASS_KEY = "customTargetConstructorClass"; + private static final String SERIALIZATION_TYPES_KEY = "types"; + private static final String LAMBDA_CAPTURING_SERIALIZATION_TYPES_KEY = "lambdaCapturingTypes"; private final RuntimeSerializationSupport serializationSupport; @@ -52,18 +55,52 @@ public SerializationConfigurationParser(RuntimeSerializationSupport serializatio public void parseAndRegister(Reader reader) throws IOException { JSONParser parser = new JSONParser(reader); Object json = parser.parse(); - for (Object serializationKey : asList(json, "first level of document must be an array of serialization lists")) { - parseSerializationDescriptorObject(asMap(serializationKey, "second level of document must be serialization descriptor objects")); + if (json instanceof List) { + parseOldConfiguration(asList(json, "first level of document must be an array of serialization lists")); + } else if (json instanceof Map) { + parseNewConfiguration(asMap(json, "first level of document must be a map of serialization types")); + } else { + throw new JSONParserException("first level of document must either be an array of serialization lists or a map of serialization types"); } } - private void parseSerializationDescriptorObject(Map data) { - checkAttributes(data, "serialization descriptor object", Collections.singleton(NAME_KEY), Arrays.asList(CUSTOM_TARGET_CONSTRUCTOR_CLASS_KEY, CONDITIONAL_KEY)); + private void parseOldConfiguration(List listOfSerializationConfigurationObjects) { + parseSerializationTypes(asList(listOfSerializationConfigurationObjects, "second level of document must be serialization descriptor objects"), false); + } + + private void parseNewConfiguration(Map listOfSerializationConfigurationObjects) { + if (!listOfSerializationConfigurationObjects.containsKey(SERIALIZATION_TYPES_KEY) || !listOfSerializationConfigurationObjects.containsKey(LAMBDA_CAPTURING_SERIALIZATION_TYPES_KEY)) { + throw new JSONParserException("second level of document must be arrays of serialization descriptor objects"); + } + + parseSerializationTypes(asList(listOfSerializationConfigurationObjects.get(SERIALIZATION_TYPES_KEY), "types must be an array of serialization descriptor objects"), false); + parseSerializationTypes( + asList(listOfSerializationConfigurationObjects.get(LAMBDA_CAPTURING_SERIALIZATION_TYPES_KEY), "lambdaCapturingTypes must be an array of serialization descriptor objects"), + true); + } + + private void parseSerializationTypes(List listOfSerializationTypes, boolean lambdaCapturingTypes) { + for (Object serializationType : listOfSerializationTypes) { + parseSerializationDescriptorObject(asMap(serializationType, "third level of document must be serialization descriptor objects"), lambdaCapturingTypes); + } + } + + private void parseSerializationDescriptorObject(Map data, boolean lambdaCapturingType) { + if (lambdaCapturingType) { + checkAttributes(data, "serialization descriptor object", Collections.singleton(NAME_KEY), Collections.singleton(CONDITIONAL_KEY)); + } else { + checkAttributes(data, "serialization descriptor object", Collections.singleton(NAME_KEY), Arrays.asList(CUSTOM_TARGET_CONSTRUCTOR_CLASS_KEY, CONDITIONAL_KEY)); + } + ConfigurationCondition unresolvedCondition = parseCondition(data); String targetSerializationClass = asString(data.get(NAME_KEY)); - Object optionalCustomCtorValue = data.get(CUSTOM_TARGET_CONSTRUCTOR_CLASS_KEY); - String customTargetConstructorClass = optionalCustomCtorValue != null ? asString(optionalCustomCtorValue) : null; - serializationSupport.registerWithTargetConstructorClass(unresolvedCondition, targetSerializationClass, customTargetConstructorClass); + if (lambdaCapturingType) { + serializationSupport.registerLambdaCapturingClass(unresolvedCondition, targetSerializationClass); + } else { + Object optionalCustomCtorValue = data.get(CUSTOM_TARGET_CONSTRUCTOR_CLASS_KEY); + String customTargetConstructorClass = optionalCustomCtorValue != null ? asString(optionalCustomCtorValue) : null; + serializationSupport.registerWithTargetConstructorClass(unresolvedCondition, targetSerializationClass, customTargetConstructorClass); + } } } diff --git a/substratevm/src/com.oracle.svm.core/src/com/oracle/svm/core/configure/doc-files/SerializationConfigurationFilesHelp.txt b/substratevm/src/com.oracle.svm.core/src/com/oracle/svm/core/configure/doc-files/SerializationConfigurationFilesHelp.txt index 5905008f3142..418ec78b947d 100644 --- a/substratevm/src/com.oracle.svm.core/src/com/oracle/svm/core/configure/doc-files/SerializationConfigurationFilesHelp.txt +++ b/substratevm/src/com.oracle.svm.core/src/com/oracle/svm/core/configure/doc-files/SerializationConfigurationFilesHelp.txt @@ -10,6 +10,17 @@ Example: } ] +For deserializing lambda classes, the capturing class of the lambda needs to be specified in a separate section of the configuration file, for example: + + [ + types: [ + {"name":"java.lang.Object"} + ], + lambdaCapturingTypes: [ + {"name":"java.util.Comparator"} + ] + ] + This JSON file format is also used for the serialization deny list. In rare cases an application might explicitly make calls to diff --git a/substratevm/src/com.oracle.svm.hosted/src/com/oracle/svm/hosted/lambda/StableLambdaProxyNameFeature.java b/substratevm/src/com.oracle.svm.hosted/src/com/oracle/svm/hosted/lambda/StableLambdaProxyNameFeature.java index 0b7e9a793b8e..109bc7ce7bb4 100644 --- a/substratevm/src/com.oracle.svm.hosted/src/com/oracle/svm/hosted/lambda/StableLambdaProxyNameFeature.java +++ b/substratevm/src/com.oracle.svm.hosted/src/com/oracle/svm/hosted/lambda/StableLambdaProxyNameFeature.java @@ -67,7 +67,7 @@ private static boolean checkLambdaNames(List types) { Set lambdaNames = new HashSet<>(); types.stream() .map(AnalysisType::getName) - .filter(x -> x.contains("$$Lambda$")) + .filter(x -> x.contains(LambdaUtils.LAMBDA_CLASS_NAME_SUBSTRING)) .forEach(name -> { if (lambdaNames.contains(name)) { throw new AssertionError("Duplicate lambda name: " + name); diff --git a/substratevm/src/com.oracle.svm.jvmtiagentbase/src/com/oracle/svm/jvmtiagentbase/JNIHandleSet.java b/substratevm/src/com.oracle.svm.jvmtiagentbase/src/com/oracle/svm/jvmtiagentbase/JNIHandleSet.java index 761f97b272a1..14b1c4ddb2b4 100644 --- a/substratevm/src/com.oracle.svm.jvmtiagentbase/src/com/oracle/svm/jvmtiagentbase/JNIHandleSet.java +++ b/substratevm/src/com.oracle.svm.jvmtiagentbase/src/com/oracle/svm/jvmtiagentbase/JNIHandleSet.java @@ -24,19 +24,18 @@ */ package com.oracle.svm.jvmtiagentbase; -import static com.oracle.svm.core.util.VMError.guarantee; -import static com.oracle.svm.jni.JNIObjectHandles.nullHandle; - -import java.util.concurrent.locks.ReentrantLock; - -import org.graalvm.nativeimage.c.type.CTypeConversion; -import org.graalvm.word.WordFactory; - import com.oracle.svm.jni.JNIObjectHandles; import com.oracle.svm.jni.nativeapi.JNIEnvironment; import com.oracle.svm.jni.nativeapi.JNIFieldId; import com.oracle.svm.jni.nativeapi.JNIMethodId; import com.oracle.svm.jni.nativeapi.JNIObjectHandle; +import org.graalvm.nativeimage.c.type.CTypeConversion; +import org.graalvm.word.WordFactory; + +import java.util.concurrent.locks.ReentrantLock; + +import static com.oracle.svm.core.util.VMError.guarantee; +import static com.oracle.svm.jni.JNIObjectHandles.nullHandle; /** * Helps with creation and management of JNI handles for JVMTI agents. @@ -59,8 +58,10 @@ public abstract class JNIHandleSet { private boolean destroyed = false; final JNIMethodId javaLangClassGetName; + final JNIObjectHandle javaIoSerializable; public JNIHandleSet(JNIEnvironment env) { + javaIoSerializable = newClassGlobalRef(env, "java/io/Serializable"); JNIObjectHandle javaLangClass = findClass(env, "java/lang/Class"); javaLangClassGetName = getMethodId(env, javaLangClass, "getName", "()Ljava/lang/String;", false); } diff --git a/substratevm/src/com.oracle.svm.jvmtiagentbase/src/com/oracle/svm/jvmtiagentbase/Support.java b/substratevm/src/com.oracle.svm.jvmtiagentbase/src/com/oracle/svm/jvmtiagentbase/Support.java index cc364f38b1af..6682b41af9bc 100644 --- a/substratevm/src/com.oracle.svm.jvmtiagentbase/src/com/oracle/svm/jvmtiagentbase/Support.java +++ b/substratevm/src/com.oracle.svm.jvmtiagentbase/src/com/oracle/svm/jvmtiagentbase/Support.java @@ -158,6 +158,10 @@ public static CCharPointerHolder toCString(String s) { return CTypeConversion.toCString(s); } + public static boolean isSerializable(JNIEnvironment env, JNIObjectHandle serializeTargetClass) { + return jniFunctions().getIsAssignableFrom().invoke(env, serializeTargetClass, JvmtiAgentBase.singleton().handles().javaIoSerializable); + } + public static JNIObjectHandle getCallerClass(int depth) { return getMethodDeclaringClass(getCallerMethod(depth)); } diff --git a/substratevm/src/com.oracle.svm.reflect/src/com/oracle/svm/reflect/serialize/SerializationSupport.java b/substratevm/src/com.oracle.svm.reflect/src/com/oracle/svm/reflect/serialize/SerializationSupport.java index e306c5c9ba0c..afe9422d8208 100644 --- a/substratevm/src/com.oracle.svm.reflect/src/com/oracle/svm/reflect/serialize/SerializationSupport.java +++ b/substratevm/src/com.oracle.svm.reflect/src/com/oracle/svm/reflect/serialize/SerializationSupport.java @@ -26,12 +26,15 @@ package com.oracle.svm.reflect.serialize; import java.io.Serializable; +// Checkstyle: stop +import java.lang.invoke.SerializedLambda; import java.lang.reflect.Constructor; import java.lang.reflect.Modifier; import java.util.Map; import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; +import org.graalvm.compiler.java.LambdaUtils; import org.graalvm.nativeimage.Platform; import org.graalvm.nativeimage.Platforms; @@ -124,18 +127,20 @@ public Object addConstructorAccessor(Class declaringClass, Class targetCon } @Override - public Object getSerializationConstructorAccessor(Class declaringClass, Class rawTargetConstructorClass) { - Class targetConstructorClass = Modifier.isAbstract(declaringClass.getModifiers()) ? stubConstructor.getDeclaringClass() : rawTargetConstructorClass; + public Object getSerializationConstructorAccessor(Class rawDeclaringClass, Class rawTargetConstructorClass) { + Class declaringClass = rawDeclaringClass; + + if (declaringClass.getName().contains(LambdaUtils.LAMBDA_CLASS_NAME_SUBSTRING)) { + declaringClass = SerializedLambda.class; + } + Class targetConstructorClass = Modifier.isAbstract(declaringClass.getModifiers()) ? stubConstructor.getDeclaringClass() : rawTargetConstructorClass; Object constructorAccessor = constructorAccessors.get(new SerializationLookupKey(declaringClass, targetConstructorClass)); if (constructorAccessor != null) { return constructorAccessor; } else { String targetConstructorClassName = targetConstructorClass.getName(); - if (targetConstructorClassName.contains("$$Lambda$")) { - throw VMError.unsupportedFeature("Can't serialize " + targetConstructorClassName + ". Lambda class serialization is currently not supported"); - } throw VMError.unsupportedFeature("SerializationConstructorAccessor class not found for declaringClass: " + declaringClass.getName() + " (targetConstructorClass: " + targetConstructorClassName + "). Usually adding " + declaringClass.getName() + " to serialization-config.json fixes the problem."); diff --git a/substratevm/src/com.oracle.svm.reflect/src/com/oracle/svm/reflect/serialize/hosted/SerializationFeature.java b/substratevm/src/com.oracle.svm.reflect/src/com/oracle/svm/reflect/serialize/hosted/SerializationFeature.java index 46154ebcbceb..433e1e52a7e2 100644 --- a/substratevm/src/com.oracle.svm.reflect/src/com/oracle/svm/reflect/serialize/hosted/SerializationFeature.java +++ b/substratevm/src/com.oracle.svm.reflect/src/com/oracle/svm/reflect/serialize/hosted/SerializationFeature.java @@ -25,32 +25,10 @@ */ package com.oracle.svm.reflect.serialize.hosted; -import static com.oracle.svm.reflect.serialize.hosted.SerializationFeature.println; -import static com.oracle.svm.reflect.serialize.hosted.SerializationFeature.warn; - -import java.io.Externalizable; -import java.io.ObjectOutputStream; -import java.io.ObjectStreamClass; -import java.io.ObjectStreamField; -import java.io.Serializable; -import java.lang.reflect.Constructor; -import java.lang.reflect.Field; -import java.lang.reflect.Method; -import java.lang.reflect.Modifier; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; - -import jdk.vm.ci.meta.JavaKind; -import org.graalvm.nativeimage.ImageSingletons; -import org.graalvm.nativeimage.hosted.Feature; -import org.graalvm.nativeimage.hosted.RuntimeReflection; -import org.graalvm.nativeimage.impl.ConfigurationCondition; -import org.graalvm.nativeimage.impl.RuntimeSerializationSupport; +// Checkstyle: allow reflection +import com.oracle.graal.pointsto.phases.NoClassInitializationPlugin; +import com.oracle.graal.pointsto.util.GraalAccess; import com.oracle.svm.core.annotate.AutomaticFeature; import com.oracle.svm.core.configure.ConfigurationFile; import com.oracle.svm.core.configure.ConfigurationFiles; @@ -70,11 +48,58 @@ import com.oracle.svm.reflect.serialize.SerializationRegistry; import com.oracle.svm.reflect.serialize.SerializationSupport; import com.oracle.svm.util.ReflectionUtil; - import jdk.internal.reflect.ReflectionFactory; +import jdk.vm.ci.hotspot.HotSpotObjectConstant; +import jdk.vm.ci.meta.Constant; +import jdk.vm.ci.meta.JavaConstant; +import jdk.vm.ci.meta.JavaKind; +import jdk.vm.ci.meta.ResolvedJavaField; +import jdk.vm.ci.meta.ResolvedJavaMethod; +import jdk.vm.ci.meta.ResolvedJavaType; +import org.graalvm.compiler.debug.DebugContext; +import org.graalvm.compiler.graph.iterators.NodeIterable; +import org.graalvm.compiler.java.GraphBuilderPhase; +import org.graalvm.compiler.java.LambdaUtils; +import org.graalvm.compiler.nodes.ConstantNode; +import org.graalvm.compiler.nodes.StructuredGraph; +import org.graalvm.compiler.nodes.graphbuilderconf.GraphBuilderConfiguration; +import org.graalvm.compiler.nodes.graphbuilderconf.InvocationPlugins; +import org.graalvm.compiler.phases.OptimisticOptimizations; +import org.graalvm.compiler.phases.tiers.HighTierContext; +import org.graalvm.compiler.replacements.MethodHandlePlugin; +import org.graalvm.nativeimage.ImageSingletons; +import org.graalvm.nativeimage.hosted.Feature; +import org.graalvm.nativeimage.hosted.RuntimeReflection; +import org.graalvm.nativeimage.impl.ConfigurationCondition; +import org.graalvm.nativeimage.impl.RuntimeSerializationSupport; + +import java.io.Externalizable; +import java.io.ObjectOutputStream; +import java.io.ObjectStreamClass; +import java.io.ObjectStreamField; +import java.io.Serializable; +import java.lang.invoke.SerializedLambda; +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.lang.reflect.Member; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static com.oracle.svm.reflect.serialize.hosted.SerializationFeature.capturingClasses; +import static com.oracle.svm.reflect.serialize.hosted.SerializationFeature.println; +import static com.oracle.svm.reflect.serialize.hosted.SerializationFeature.warn; @AutomaticFeature public class SerializationFeature implements Feature { + static final HashSet> capturingClasses = new HashSet<>(); private SerializationBuilder serializationBuilder; private int loadedConfigurations; @@ -103,8 +128,108 @@ public void duringSetup(DuringSetupAccess a) { ConfigurationFile.SERIALIZATION.getFileName()); } + private static GraphBuilderConfiguration buildLambdaParserConfig() { + GraphBuilderConfiguration.Plugins plugins = new GraphBuilderConfiguration.Plugins(new InvocationPlugins()); + plugins.setClassInitializationPlugin(new NoClassInitializationPlugin()); + plugins.prependNodePlugin(new MethodHandlePlugin(GraalAccess.getOriginalProviders().getConstantReflection().getMethodHandleAccess(), false)); + return GraphBuilderConfiguration.getDefault(plugins).withEagerResolving(true); + } + + @SuppressWarnings("try") + private static StructuredGraph createMethodGraph(ResolvedJavaMethod method, GraphBuilderPhase lambdaParserPhase, DebugContext debug) { + StructuredGraph graph = new StructuredGraph.Builder(debug.getOptions(), debug).method(method).build(); + try (DebugContext.Scope ignored = debug.scope("ParsingToMaterializeLambdas")) { + HighTierContext context = new HighTierContext(GraalAccess.getOriginalProviders(), null, OptimisticOptimizations.NONE); + lambdaParserPhase.apply(graph, context); + } catch (Throwable e) { + throw debug.handle(e); + } + return graph; + } + + private static Class getLambdaClassFromMemberField(Constant constant) { + ResolvedJavaType constantType = GraalAccess.getOriginalProviders().getMetaAccess().lookupJavaType((JavaConstant) constant); + + if (constantType == null) { + return null; + } + + ResolvedJavaField[] fields = constantType.getInstanceFields(true); + ResolvedJavaField targetField = null; + for (ResolvedJavaField field : fields) { + if (field.getName().equals("member")) { + targetField = field; + break; + } + } + + if (targetField == null) { + return null; + } + + HotSpotObjectConstant fieldValue = (HotSpotObjectConstant) GraalAccess.getOriginalProviders().getConstantReflection().readFieldValue(targetField, (JavaConstant) constant); + Member memberField = GraalAccess.getOriginalProviders().getSnippetReflection().asObject(Member.class, fieldValue); + return memberField.getDeclaringClass(); + } + + private static Class getLambdaClassFromConstantNode(ConstantNode constantNode) { + Constant constant = constantNode.getValue(); + Class lambdaClass = getLambdaClassFromMemberField(constant); + + if (lambdaClass == null) { + return null; + } + + return lambdaClass.getName().contains(LambdaUtils.LAMBDA_CLASS_NAME_SUBSTRING) ? lambdaClass : null; + } + + private static void registerLambdasFromConstantNodesInGraph(StructuredGraph graph) { + NodeIterable constantNodes = ConstantNode.getConstantNodes(graph); + + for (ConstantNode cNode : constantNodes) { + Class lambdaClass = getLambdaClassFromConstantNode(cNode); + + if (lambdaClass != null) { + try { + Method serializeLambdaMethod = lambdaClass.getDeclaredMethod("writeReplace"); + RuntimeReflection.register(serializeLambdaMethod); + } catch (NoSuchMethodException e) { + throw VMError.shouldNotReachHere("Serializable lambda class must contain the writeReplace method."); + } + } + } + } + + @SuppressWarnings("try") + private static void registerLambdasFromMethod(ResolvedJavaMethod method, DebugContext debug) { + GraphBuilderPhase lambdaParserPhase = new GraphBuilderPhase(buildLambdaParserConfig()); + StructuredGraph graph = createMethodGraph(method, lambdaParserPhase, debug); + registerLambdasFromConstantNodesInGraph(graph); + } + @Override public void beforeAnalysis(BeforeAnalysisAccess access) { + FeatureImpl.BeforeAnalysisAccessImpl impl = (FeatureImpl.BeforeAnalysisAccessImpl) access; + + /* + * In order to serialize lambda classes we need to register proper methods for reflection. + * Since lambda names are not stable, we do not know which lambdas should be serialized. We + * simply register all the lambdas from capturing classes written in the serialization + * configuration file for serialization. In order to find all the lambdas from a class, we + * parse all the methods of the given class and find all the lambdas in them. + */ + for (Class clazz : capturingClasses) { + ResolvedJavaType clazzType = GraalAccess.getOriginalProviders().getMetaAccess().lookupJavaType(clazz); + List allMethods = new ArrayList<>(Arrays.asList(clazzType.getDeclaredMethods())); + allMethods.addAll(Arrays.asList(clazzType.getDeclaredConstructors())); + + for (ResolvedJavaMethod method : allMethods) { + if (method.hasBytecodes()) { + registerLambdasFromMethod(method, impl.getDebugContext()); + } + } + } + serializationBuilder.flushConditionalConfiguration(access); /* Ensure SharedSecrets.javaObjectInputStreamAccess is initialized before scanning. */ ((BeforeAnalysisAccessImpl) access).ensureInitialized("java.io.ObjectInputStream"); @@ -171,6 +296,14 @@ public void registerWithTargetConstructorClass(ConfigurationCondition condition, registerWithTargetConstructorClass(condition, typeResolver.resolveType(className), null); } + @Override + public void registerLambdaCapturingClass(ConfigurationCondition condition, String lambdaCapturingClassName) { + Class lambdaCapturingClass = typeResolver.resolveType(lambdaCapturingClassName); + if (lambdaCapturingClass != null) { + deniedClasses.put(lambdaCapturingClass, true); + } + } + @Override public void registerWithTargetConstructorClass(ConfigurationCondition condition, Class clazz, Class customTargetConstructorClazz) { if (clazz != null) { @@ -309,9 +442,21 @@ public void registerWithTargetConstructorClass(ConfigurationCondition condition, } } + @Override + public void registerLambdaCapturingClass(ConfigurationCondition condition, String lambdaCapturingClassName) { + Class serializationTargetClass = typeResolver.resolveType(lambdaCapturingClassName); + + registerConditionalConfiguration(condition, () -> { + capturingClasses.add(serializationTargetClass); + RuntimeReflection.register(serializationTargetClass); + }); + RuntimeReflection.register(ReflectionUtil.lookupMethod(true, serializationTargetClass, "$deserializeLambda$", SerializedLambda.class)); + } + @Override public void registerWithTargetConstructorClass(ConfigurationCondition condition, Class serializationTargetClass, Class customTargetConstructorClass) { abortIfSealed(); + if (!Serializable.class.isAssignableFrom(serializationTargetClass)) { println("Warning: Could not register " + serializationTargetClass.getName() + " for serialization as it does not implement Serializable."); } else if (denyRegistry.isAllowed(serializationTargetClass)) { @@ -421,7 +566,6 @@ Class addConstructorAccessor(Class serializationTargetClass, Class cust Class targetConstructorClass; if (Modifier.isAbstract(serializationTargetClass.getModifiers())) { targetConstructor = stubConstructor; - targetConstructorClass = targetConstructor.getDeclaringClass(); } else { if (customTargetConstructorClass == serializationTargetClass) { /* No custom constructor needed. Simply use existing no-arg constructor. */ @@ -437,8 +581,13 @@ Class addConstructorAccessor(Class serializationTargetClass, Class cust } } targetConstructor = newConstructorForSerialization(serializationTargetClass, customConstructorToCall); - targetConstructorClass = targetConstructor.getDeclaringClass(); + + if (targetConstructor == null) { + targetConstructor = newConstructorForSerialization(Object.class, customConstructorToCall); + } + } + targetConstructorClass = targetConstructor.getDeclaringClass(); serializationSupport.addConstructorAccessor(serializationTargetClass, targetConstructorClass, getConstructorAccessor(targetConstructor)); return targetConstructorClass; } diff --git a/substratevm/src/com.oracle.svm.test/src/META-INF/native-image/com.oracle.svm.test/native-image.properties b/substratevm/src/com.oracle.svm.test/src/META-INF/native-image/com.oracle.svm.test/native-image.properties index 475599568472..599867d19e66 100644 --- a/substratevm/src/com.oracle.svm.test/src/META-INF/native-image/com.oracle.svm.test/native-image.properties +++ b/substratevm/src/com.oracle.svm.test/src/META-INF/native-image/com.oracle.svm.test/native-image.properties @@ -1,9 +1,10 @@ Args= \ --initialize-at-run-time=com.oracle.svm.test \ - --initialize-at-build-time=com.oracle.svm.test.AbstractClassSerializationTest,com.oracle.svm.test.SerializationRegistrationTest \ + --initialize-at-build-time=com.oracle.svm.test.AbstractClassSerializationTest,com.oracle.svm.test.SerializationRegistrationTest,com.oracle.svm.test.LambdaClassSerializationTest,com.oracle.svm.test.LambdaClassDeserializationTest \ --features=com.oracle.svm.test.SerializationRegistrationTestFeature \ --features=com.oracle.svm.test.AbstractServiceLoaderTest$TestFeature \ --features=com.oracle.svm.test.NoProviderConstructorServiceLoaderTest$TestFeature \ --features=com.oracle.svm.test.NativeImageResourceFileSystemProviderTest$TestFeature \ -H:+AllowVMInspection \ - --add-exports=org.graalvm.nativeimage.builder/com.oracle.svm.core.containers=ALL-UNNAMED + --features=com.oracle.svm.test.AbstractServiceLoaderTest$TestFeature,com.oracle.svm.test.NoProviderConstructorServiceLoaderTest$TestFeature \ + --add-exports=org.graalvm.nativeimage.builder/com.oracle.svm.core.containers=ALL-UNNAMED \ No newline at end of file diff --git a/substratevm/src/com.oracle.svm.test/src/META-INF/native-image/serialization-config.json b/substratevm/src/com.oracle.svm.test/src/META-INF/native-image/serialization-config.json index 720c618e584a..e85ebe7c5387 100644 --- a/substratevm/src/com.oracle.svm.test/src/META-INF/native-image/serialization-config.json +++ b/substratevm/src/com.oracle.svm.test/src/META-INF/native-image/serialization-config.json @@ -1,14 +1,33 @@ -[ - { - "name":"java.util.HashMap" - }, - { - "name":"java.lang.String" - }, - { - "name":"java.lang.Number" - }, - { - "name":"java.lang.Integer" - } -] \ No newline at end of file +{ + "types":[ + { + "name":"java.lang.Object[]" + }, + { + "name":"java.lang.String" + }, + { + "name":"java.lang.invoke.SerializedLambda" + }, + { + "name":"com.oracle.svm.test.AbstractClassSerializationTest" + }, + { + "name":"java.lang.Integer" + }, + { + "name":"java.lang.Number" + }, + { + "name":"java.util.HashMap" + } + ], + "lambdaCapturingTypes":[ + { + "name":"com.oracle.svm.test.LambdaClassSerializationTest" + }, + { + "name":"com.oracle.svm.test.LambdaClassDeserializationTest$SerializeLambda" + } + ] +} diff --git a/substratevm/src/com.oracle.svm.test/src/com/oracle/svm/test/LambdaClassDeserializationTest.java b/substratevm/src/com.oracle.svm.test/src/com/oracle/svm/test/LambdaClassDeserializationTest.java new file mode 100644 index 000000000000..d07f87ec914f --- /dev/null +++ b/substratevm/src/com.oracle.svm.test/src/com/oracle/svm/test/LambdaClassDeserializationTest.java @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2021, 2021, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package com.oracle.svm.test; + +import org.junit.Assert; +import org.junit.Test; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.util.function.Function; + +// By declaring and serializing lambda in one class and deserializing it in another, we can simulate situation where program only +// deserializes lambda class +public class LambdaClassDeserializationTest { + private ByteArrayOutputStream byteArrayOutputStream; + + private static class SerializeLambda { + @SuppressWarnings("unchecked") + public static Function createLambda() { + return (Function & Serializable) (x) -> "Value of parameter is " + x; + } + + public static void serialize(ByteArrayOutputStream byteArrayOutputStream, Serializable serializableObject) throws IOException { + ObjectOutputStream objectOutputStream = new ObjectOutputStream(byteArrayOutputStream); + objectOutputStream.writeObject(serializableObject); + objectOutputStream.close(); + } + } + + private static class DeserializeLambda { + public static Object deserialize(ByteArrayOutputStream byteArrayOutputStream) throws IOException, ClassNotFoundException { + ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(byteArrayOutputStream.toByteArray()); + ObjectInputStream objectInputStream = new ObjectInputStream(byteArrayInputStream); + return objectInputStream.readObject(); + } + } + + @Test + public void testLambdaLambdaDeserialization() throws Exception { + byteArrayOutputStream = new ByteArrayOutputStream(); + + int n = 10; + + Function lambda = SerializeLambda.createLambda(); + String originalLambdaString = lambda.apply(n); + + SerializeLambda.serialize(byteArrayOutputStream, (Serializable) lambda); + + @SuppressWarnings("unchecked") + Function deserializedFn = (Function) DeserializeLambda.deserialize(byteArrayOutputStream); + + String deserializedLambdaString = deserializedFn.apply(n); + + Assert.assertEquals(originalLambdaString, deserializedLambdaString); + } +} diff --git a/substratevm/src/com.oracle.svm.test/src/com/oracle/svm/test/LambdaClassSerializationTest.java b/substratevm/src/com.oracle.svm.test/src/com/oracle/svm/test/LambdaClassSerializationTest.java new file mode 100644 index 000000000000..5948702cc83f --- /dev/null +++ b/substratevm/src/com.oracle.svm.test/src/com/oracle/svm/test/LambdaClassSerializationTest.java @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2021, 2021, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package com.oracle.svm.test; + +import org.junit.Assert; +import org.junit.Test; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.util.function.Function; + +public class LambdaClassSerializationTest { + private ByteArrayOutputStream byteArrayOutputStream; + + private static void serialize(ByteArrayOutputStream byteArrayOutputStream, Serializable serializableObject) throws IOException { + ObjectOutputStream objectOutputStream = new ObjectOutputStream(byteArrayOutputStream); + objectOutputStream.writeObject(serializableObject); + objectOutputStream.close(); + } + + private static Object deserialize(ByteArrayOutputStream byteArrayOutputStream) throws IOException, ClassNotFoundException { + ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(byteArrayOutputStream.toByteArray()); + ObjectInputStream objectInputStream = new ObjectInputStream(byteArrayInputStream); + return objectInputStream.readObject(); + } + + @Test + public void testLambdaClassSerialization() throws Exception { + byteArrayOutputStream = new ByteArrayOutputStream(); + + int n = 10; + + @SuppressWarnings("unchecked") + Function function = (Function & Serializable) (x) -> "Value of parameter is " + x; + String originalLambdaString = function.apply(n); + + serialize(byteArrayOutputStream, (Serializable) function); + + @SuppressWarnings("unchecked") + Function deserializedFunction = (Function) deserialize(byteArrayOutputStream); + + String deserializedLambdaString = deserializedFunction.apply(n); + + Assert.assertEquals(originalLambdaString, deserializedLambdaString); + } +}