diff --git a/substratevm/src/com.oracle.svm.core/src/com/oracle/svm/core/jni/JNITestingBackdoor.java b/substratevm/src/com.oracle.svm.core/src/com/oracle/svm/core/jni/JNITestingBackdoor.java index e6c1c2642303..91e504149a51 100644 --- a/substratevm/src/com.oracle.svm.core/src/com/oracle/svm/core/jni/JNITestingBackdoor.java +++ b/substratevm/src/com.oracle.svm.core/src/com/oracle/svm/core/jni/JNITestingBackdoor.java @@ -44,7 +44,7 @@ public static int getThreadLocalPinnedObjectCount() { } public static long getMethodID(Class clazz, String name, String signature, boolean isStatic) { - return JNIReflectionDictionary.singleton().getMethodID(clazz, name, signature, isStatic).rawValue(); + return JNIReflectionDictionary.getMethodID(clazz, name, signature, isStatic).rawValue(); } public static int getThreadLocalOwnedMonitorsCount() { diff --git a/substratevm/src/com.oracle.svm.core/src/com/oracle/svm/core/jni/access/JNIReflectionDictionary.java b/substratevm/src/com.oracle.svm.core/src/com/oracle/svm/core/jni/access/JNIReflectionDictionary.java index 35d4f96ead00..ef632530708b 100644 --- a/substratevm/src/com.oracle.svm.core/src/com/oracle/svm/core/jni/access/JNIReflectionDictionary.java +++ b/substratevm/src/com.oracle.svm.core/src/com/oracle/svm/core/jni/access/JNIReflectionDictionary.java @@ -28,6 +28,7 @@ import static com.oracle.svm.core.SubstrateOptions.JNIVerboseLookupErrors; import java.io.PrintStream; +import java.util.EnumSet; import java.util.Map; import java.util.function.Function; @@ -46,6 +47,9 @@ import com.oracle.svm.core.jni.MissingJNIRegistrationUtils; import com.oracle.svm.core.jni.headers.JNIFieldId; import com.oracle.svm.core.jni.headers.JNIMethodId; +import com.oracle.svm.core.layeredimagesingleton.LayeredImageSingletonBuilderFlags; +import com.oracle.svm.core.layeredimagesingleton.MultiLayeredImageSingleton; +import com.oracle.svm.core.layeredimagesingleton.UnsavedSingleton; import com.oracle.svm.core.log.Log; import com.oracle.svm.core.snippets.KnownIntrinsics; import com.oracle.svm.core.util.ImageHeapMap; @@ -61,7 +65,7 @@ /** * Provides JNI access to predetermined classes, methods and fields at runtime. */ -public final class JNIReflectionDictionary { +public final class JNIReflectionDictionary implements MultiLayeredImageSingleton, UnsavedSingleton { /** * Enables lookups with {@link WrappedAsciiCString}, which avoids many unnecessary character set * conversions and allocations. @@ -91,10 +95,15 @@ public static void create() { ImageSingletons.add(JNIReflectionDictionary.class, new JNIReflectionDictionary()); } + @Platforms(HOSTED_ONLY.class) public static JNIReflectionDictionary singleton() { return ImageSingletons.lookup(JNIReflectionDictionary.class); } + private static JNIReflectionDictionary[] layeredSingletons() { + return MultiLayeredImageSingleton.getAllLayers(JNIReflectionDictionary.class); + } + private final EconomicMap classesByName = ImageHeapMap.create(WRAPPED_CSTRING_EQUIVALENCE); private final EconomicMap, JNIAccessibleClass> classesByClassObject = ImageHeapMap.create(); private final EconomicMap nativeLinkages = ImageHeapMap.create(); @@ -102,36 +111,40 @@ public static JNIReflectionDictionary singleton() { private JNIReflectionDictionary() { } - private void dump(boolean condition, String label) { + private static void dump(boolean condition, String label) { if (JNIVerboseLookupErrors.getValue() && condition) { - PrintStream ps = Log.logStream(); - ps.println(label); - ps.println(" classesByName:"); - MapCursor nameCursor = classesByName.getEntries(); - while (nameCursor.advance()) { - ps.print(" "); - ps.println(nameCursor.getKey()); - JNIAccessibleClass clazz = nameCursor.getValue(); - ps.println(" methods:"); - MapCursor methodsCursor = clazz.getMethods(); - while (methodsCursor.advance()) { - ps.print(" "); - ps.print(methodsCursor.getKey().getName()); - ps.println(methodsCursor.getKey().getSignature()); + int layerNum = 0; + for (var dictionary : layeredSingletons()) { + PrintStream ps = Log.logStream(); + ps.println("Layer " + layerNum); + ps.println(label); + ps.println(" classesByName:"); + MapCursor nameCursor = dictionary.classesByName.getEntries(); + while (nameCursor.advance()) { + ps.print(" "); + ps.println(nameCursor.getKey()); + JNIAccessibleClass clazz = nameCursor.getValue(); + ps.println(" methods:"); + MapCursor methodsCursor = clazz.getMethods(); + while (methodsCursor.advance()) { + ps.print(" "); + ps.print(methodsCursor.getKey().getName()); + ps.println(methodsCursor.getKey().getSignature()); + } + ps.println(" fields:"); + UnmodifiableMapCursor fieldsCursor = clazz.getFields(); + while (fieldsCursor.advance()) { + ps.print(" "); + ps.println(fieldsCursor.getKey()); + } } - ps.println(" fields:"); - UnmodifiableMapCursor fieldsCursor = clazz.getFields(); - while (fieldsCursor.advance()) { - ps.print(" "); - ps.println(fieldsCursor.getKey()); - } - } - ps.println(" classesByClassObject:"); - MapCursor, JNIAccessibleClass> cursor = classesByClassObject.getEntries(); - while (cursor.advance()) { - ps.print(" "); - ps.println(cursor.getKey()); + ps.println(" classesByClassObject:"); + MapCursor, JNIAccessibleClass> cursor = dictionary.classesByClassObject.getEntries(); + while (cursor.advance()) { + ps.print(" "); + ps.println(cursor.getKey()); + } } } } @@ -151,6 +164,7 @@ public JNIAccessibleClass addClassIfAbsent(Class classObj, Function, return classesByClassObject.get(classObj); } + @Platforms(HOSTED_ONLY.class) public void addNegativeClassLookupIfAbsent(String typeName) { String internalName = MetaUtil.toInternalName(typeName); String queryName = internalName.startsWith("L") ? internalName.substring(1, internalName.length() - 1) : internalName; @@ -162,15 +176,21 @@ public void addLinkages(Map linkages) { nativeLinkages.putAll(EconomicMap.wrapMap(linkages)); } + @Platforms(HOSTED_ONLY.class) public Iterable getClasses() { return classesByClassObject.getValues(); } - public Class getClassObjectByName(CharSequence name) { - JNIAccessibleClass clazz = classesByName.get(name); - clazz = checkClass(clazz, name); - dump(clazz == null, "getClassObjectByName"); - return (clazz != null) ? clazz.getClassObject() : null; + public static Class getClassObjectByName(CharSequence name) { + for (var dictionary : layeredSingletons()) { + JNIAccessibleClass clazz = dictionary.classesByName.get(name); + clazz = checkClass(clazz, name); + if (clazz != null) { + return clazz.getClassObject(); + } + } + dump(true, "getClassObjectByName"); + return null; } private static JNIAccessibleClass checkClass(JNIAccessibleClass clazz, CharSequence name) { @@ -192,20 +212,28 @@ private static JNIAccessibleClass checkClass(JNIAccessibleClass clazz, CharSeque * method * @return the linkage for the native method or {@code null} if no linkage exists */ - public JNINativeLinkage getLinkage(CharSequence declaringClass, CharSequence name, CharSequence descriptor) { + public static JNINativeLinkage getLinkage(CharSequence declaringClass, CharSequence name, CharSequence descriptor) { JNINativeLinkage key = new JNINativeLinkage(declaringClass, name, descriptor); - return nativeLinkages.get(key); + for (var dictionary : layeredSingletons()) { + var linkage = dictionary.nativeLinkages.get(key); + if (linkage != null) { + return linkage; + } + } + return null; } - public void unsetEntryPoints(String declaringClass) { - for (JNINativeLinkage linkage : nativeLinkages.getKeys()) { - if (declaringClass.equals(linkage.getDeclaringClassName())) { - linkage.unsetEntryPoint(); + public static void unsetEntryPoints(String declaringClass) { + for (var dictionary : layeredSingletons()) { + for (JNINativeLinkage linkage : dictionary.nativeLinkages.getKeys()) { + if (declaringClass.equals(linkage.getDeclaringClassName())) { + linkage.unsetEntryPoint(); + } } } } - private JNIAccessibleMethod findMethod(Class clazz, JNIAccessibleMethodDescriptor descriptor, String dumpLabel) { + private static JNIAccessibleMethod findMethod(Class clazz, JNIAccessibleMethodDescriptor descriptor, String dumpLabel) { JNIAccessibleMethod method = getDeclaredMethod(clazz, descriptor, dumpLabel); if (descriptor.isConstructor() || descriptor.isClassInitializer()) { // never recurse return method; @@ -220,7 +248,7 @@ private JNIAccessibleMethod findMethod(Class clazz, JNIAccessibleMethodDescri return method; } - private JNIAccessibleMethod findSuperinterfaceMethod(Class clazz, JNIAccessibleMethodDescriptor descriptor) { + private static JNIAccessibleMethod findSuperinterfaceMethod(Class clazz, JNIAccessibleMethodDescriptor descriptor) { for (Class parent : clazz.getInterfaces()) { JNIAccessibleMethod method = getDeclaredMethod(parent, descriptor, null); if (method == null) { @@ -234,23 +262,29 @@ private JNIAccessibleMethod findSuperinterfaceMethod(Class clazz, JNIAccessib return null; } - public JNIMethodId getDeclaredMethodID(Class classObject, JNIAccessibleMethodDescriptor descriptor, boolean isStatic) { + public static JNIMethodId getDeclaredMethodID(Class classObject, JNIAccessibleMethodDescriptor descriptor, boolean isStatic) { JNIAccessibleMethod method = getDeclaredMethod(classObject, descriptor, "getDeclaredMethodID"); boolean match = (method != null && method.isStatic() == isStatic); return toMethodID(match ? method : null); } - private JNIAccessibleMethod getDeclaredMethod(Class classObject, JNIAccessibleMethodDescriptor descriptor, String dumpLabel) { - JNIAccessibleClass clazz = classesByClassObject.get(classObject); - dump(clazz == null && dumpLabel != null, dumpLabel); - JNIAccessibleMethod method = null; - if (clazz != null) { - method = clazz.getMethod(descriptor); + private static JNIAccessibleMethod getDeclaredMethod(Class classObject, JNIAccessibleMethodDescriptor descriptor, String dumpLabel) { + boolean foundClass = false; + for (var dictionary : layeredSingletons()) { + JNIAccessibleClass clazz = dictionary.classesByClassObject.get(classObject); + if (clazz != null) { + foundClass = true; + JNIAccessibleMethod method = clazz.getMethod(descriptor); + if (method != null) { + return method; + } + } } - return method; + dump(!foundClass && dumpLabel != null, dumpLabel); + return null; } - public JNIMethodId getMethodID(Class classObject, CharSequence name, CharSequence signature, boolean isStatic) { + public static JNIMethodId getMethodID(Class classObject, CharSequence name, CharSequence signature, boolean isStatic) { JNIAccessibleMethod method = findMethod(classObject, new JNIAccessibleMethodDescriptor(name, signature), "getMethodID"); method = checkMethod(method, classObject, name, signature); boolean match = (method != null && method.isStatic() == isStatic && method.isDiscoverableIn(classObject)); @@ -289,25 +323,29 @@ private static JNIAccessibleMethod checkMethod(JNIAccessibleMethod method, Class return method; } - private JNIAccessibleField getDeclaredField(Class classObject, CharSequence name, boolean isStatic, String dumpLabel) { - JNIAccessibleClass clazz = classesByClassObject.get(classObject); - dump(clazz == null && dumpLabel != null, dumpLabel); - if (clazz != null) { - JNIAccessibleField field = clazz.getField(name); - if (field != null && (field.isStatic() == isStatic || field.isNegative())) { - return field; + private static JNIAccessibleField getDeclaredField(Class classObject, CharSequence name, boolean isStatic, String dumpLabel) { + boolean foundClass = false; + for (var dictionary : layeredSingletons()) { + JNIAccessibleClass clazz = dictionary.classesByClassObject.get(classObject); + if (clazz != null) { + foundClass = true; + JNIAccessibleField field = clazz.getField(name); + if (field != null && (field.isStatic() == isStatic || field.isNegative())) { + return field; + } } } + dump(!foundClass && dumpLabel != null, dumpLabel); return null; } - public JNIFieldId getDeclaredFieldID(Class classObject, String name, boolean isStatic) { + public static JNIFieldId getDeclaredFieldID(Class classObject, String name, boolean isStatic) { JNIAccessibleField field = getDeclaredField(classObject, name, isStatic, "getDeclaredFieldID"); field = checkField(field, classObject, name); return (field != null) ? field.getId() : Word.nullPointer(); } - private JNIAccessibleField findField(Class clazz, CharSequence name, boolean isStatic, String dumpLabel) { + private static JNIAccessibleField findField(Class clazz, CharSequence name, boolean isStatic, String dumpLabel) { // Lookup according to JVM spec 5.4.3.2: local fields, superinterfaces, superclasses JNIAccessibleField field = getDeclaredField(clazz, name, isStatic, dumpLabel); if (field == null && isStatic) { @@ -319,7 +357,7 @@ private JNIAccessibleField findField(Class clazz, CharSequence name, boolean return field; } - private JNIAccessibleField findSuperinterfaceField(Class clazz, CharSequence name) { + private static JNIAccessibleField findSuperinterfaceField(Class clazz, CharSequence name) { for (Class parent : clazz.getInterfaces()) { JNIAccessibleField field = getDeclaredField(parent, name, true, null); if (field == null) { @@ -332,21 +370,23 @@ private JNIAccessibleField findSuperinterfaceField(Class clazz, CharSequence return null; } - public JNIFieldId getFieldID(Class clazz, CharSequence name, boolean isStatic) { + public static JNIFieldId getFieldID(Class clazz, CharSequence name, boolean isStatic) { JNIAccessibleField field = findField(clazz, name, isStatic, "getFieldID"); field = checkField(field, clazz, name); return (field != null && field.isDiscoverableIn(clazz)) ? field.getId() : Word.nullPointer(); } - public String getFieldNameByID(Class classObject, JNIFieldId id) { - JNIAccessibleClass clazz = classesByClassObject.get(classObject); - if (clazz != null) { - UnmodifiableMapCursor fieldsCursor = clazz.getFields(); - while (fieldsCursor.advance()) { - JNIAccessibleField field = fieldsCursor.getValue(); - if (id.equal(field.getId())) { - VMError.guarantee(!field.isNegative(), "Existing fields can't correspond to a negative query"); - return (String) fieldsCursor.getKey(); + public static String getFieldNameByID(Class classObject, JNIFieldId id) { + for (var dictionary : layeredSingletons()) { + JNIAccessibleClass clazz = dictionary.classesByClassObject.get(classObject); + if (clazz != null) { + UnmodifiableMapCursor fieldsCursor = clazz.getFields(); + while (fieldsCursor.advance()) { + JNIAccessibleField field = fieldsCursor.getValue(); + if (id.equal(field.getId())) { + VMError.guarantee(!field.isNegative(), "Existing fields can't correspond to a negative query"); + return (String) fieldsCursor.getKey(); + } } } } @@ -375,4 +415,8 @@ public static JNIAccessibleMethodDescriptor getMethodDescriptor(JNIAccessibleMet return null; } + @Override + public EnumSet getImageBuilderFlags() { + return LayeredImageSingletonBuilderFlags.ALL_ACCESS; + } } diff --git a/substratevm/src/com.oracle.svm.core/src/com/oracle/svm/core/jni/functions/JNIFunctions.java b/substratevm/src/com.oracle.svm.core/src/com/oracle/svm/core/jni/functions/JNIFunctions.java index 3fa74785a3e1..5a754871fcaa 100644 --- a/substratevm/src/com.oracle.svm.core/src/com/oracle/svm/core/jni/functions/JNIFunctions.java +++ b/substratevm/src/com.oracle.svm.core/src/com/oracle/svm/core/jni/functions/JNIFunctions.java @@ -34,7 +34,6 @@ import java.nio.ByteBuffer; import java.util.Arrays; -import jdk.graal.compiler.word.Word; import org.graalvm.nativeimage.ImageSingletons; import org.graalvm.nativeimage.IsolateThread; import org.graalvm.nativeimage.LogHandler; @@ -129,6 +128,7 @@ import jdk.graal.compiler.core.common.SuppressFBWarnings; import jdk.graal.compiler.nodes.java.ArrayLengthNode; import jdk.graal.compiler.serviceprovider.JavaVersionUtil; +import jdk.graal.compiler.word.Word; import jdk.internal.misc.Unsafe; import jdk.vm.ci.meta.JavaKind; import jdk.vm.ci.meta.MetaUtil; @@ -366,7 +366,7 @@ static JNIObjectHandle FindClass(JNIEnvironment env, CCharPointer cname) { throw new NoClassDefFoundError("Class name is either null or invalid UTF-8 string"); } - Class clazz = JNIReflectionDictionary.singleton().getClassObjectByName(name); + Class clazz = JNIReflectionDictionary.getClassObjectByName(name); if (clazz == null) { throw new NoClassDefFoundError(name.toString()); } @@ -400,7 +400,7 @@ static int RegisterNatives(JNIEnvironment env, JNIObjectHandle hclazz, JNINative CFunctionPointer fnPtr = entry.fnPtr(); String declaringClass = MetaUtil.toInternalName(clazz.getName()); - JNINativeLinkage linkage = JNIReflectionDictionary.singleton().getLinkage(declaringClass, name, signature); + JNINativeLinkage linkage = JNIReflectionDictionary.getLinkage(declaringClass, name, signature); if (linkage != null) { linkage.setEntryPoint(fnPtr); } else { @@ -427,7 +427,7 @@ static int RegisterNatives(JNIEnvironment env, JNIObjectHandle hclazz, JNINative static int UnregisterNatives(JNIEnvironment env, JNIObjectHandle hclazz) { Class clazz = JNIObjectHandles.getObject(hclazz); String internalName = MetaUtil.toInternalName(clazz.getName()); - JNIReflectionDictionary.singleton().unsetEntryPoints(internalName); + JNIReflectionDictionary.unsetEntryPoints(internalName); return JNIErrors.JNI_OK(); } @@ -952,7 +952,7 @@ static JNIFieldId FromReflectedField(JNIEnvironment env, JNIObjectHandle fieldHa Field obj = JNIObjectHandles.getObject(fieldHandle); if (obj != null) { boolean isStatic = Modifier.isStatic(obj.getModifiers()); - fieldId = JNIReflectionDictionary.singleton().getDeclaredFieldID(obj.getDeclaringClass(), obj.getName(), isStatic); + fieldId = JNIReflectionDictionary.getDeclaredFieldID(obj.getDeclaringClass(), obj.getName(), isStatic); } return fieldId; } @@ -966,7 +966,7 @@ static JNIObjectHandle ToReflectedField(JNIEnvironment env, JNIObjectHandle clas Field field = null; Class clazz = JNIObjectHandles.getObject(classHandle); if (clazz != null) { - String name = JNIReflectionDictionary.singleton().getFieldNameByID(clazz, fieldId); + String name = JNIReflectionDictionary.getFieldNameByID(clazz, fieldId); if (name != null) { try { field = clazz.getDeclaredField(name); @@ -989,7 +989,7 @@ static JNIMethodId FromReflectedMethod(JNIEnvironment env, JNIObjectHandle metho if (method != null) { boolean isStatic = Modifier.isStatic(method.getModifiers()); JNIAccessibleMethodDescriptor descriptor = JNIAccessibleMethodDescriptor.of(method); - methodId = JNIReflectionDictionary.singleton().getDeclaredMethodID(method.getDeclaringClass(), descriptor, isStatic); + methodId = JNIReflectionDictionary.getDeclaredMethodID(method.getDeclaringClass(), descriptor, isStatic); } return methodId; } @@ -1839,10 +1839,10 @@ static JNIMethodId getMethodID(JNIObjectHandle hclazz, CCharPointer cname, CChar } private static JNIMethodId getMethodID(Class clazz, CharSequence name, CharSequence signature, boolean isStatic) { - JNIMethodId methodID = JNIReflectionDictionary.singleton().getMethodID(clazz, name, signature, isStatic); + JNIMethodId methodID = JNIReflectionDictionary.getMethodID(clazz, name, signature, isStatic); if (methodID.isNull()) { String message = clazz.getName() + "." + name + signature; - JNIMethodId candidate = JNIReflectionDictionary.singleton().getMethodID(clazz, name, signature, !isStatic); + JNIMethodId candidate = JNIReflectionDictionary.getMethodID(clazz, name, signature, !isStatic); if (candidate.isNonNull()) { if (isStatic) { message += " (found matching non-static method that would be returned by GetMethodID)"; @@ -1864,7 +1864,7 @@ static JNIFieldId getFieldID(JNIObjectHandle hclazz, CCharPointer cname, CCharPo throw new NoSuchFieldError("Field name is either null or invalid UTF-8 string"); } - JNIFieldId fieldID = JNIReflectionDictionary.singleton().getFieldID(clazz, name, isStatic); + JNIFieldId fieldID = JNIReflectionDictionary.getFieldID(clazz, name, isStatic); if (fieldID.isNull()) { throw new NoSuchFieldError(clazz.getName() + '.' + name); }