2828import org .objectweb .asm .util .Printer ;
2929
3030import java .lang .reflect .Constructor ;
31+ import java .lang .reflect .Method ;
3132import java .net .MalformedURLException ;
3233import java .net .URL ;
3334import java .security .CodeSource ;
3435import java .security .SecureClassLoader ;
3536import java .security .cert .Certificate ;
37+ import java .util .Collections ;
38+ import java .util .HashMap ;
3639import java .util .Map ;
3740import java .util .concurrent .atomic .AtomicInteger ;
3841
@@ -89,16 +92,11 @@ final class Loader extends SecureClassLoader {
8992 */
9093 @ Override
9194 public Class <?> findClass (String name ) throws ClassNotFoundException {
92- if (scriptClass .getName ().equals (name )) {
93- return scriptClass ;
95+ Class <?> found = additionalClasses .get (name );
96+ if (found != null ) {
97+ return found ;
9498 }
95- if (factoryClass != null && factoryClass .getName ().equals (name )) {
96- return factoryClass ;
97- }
98- if (statefulFactoryClass != null && statefulFactoryClass .getName ().equals (name )) {
99- return statefulFactoryClass ;
100- }
101- Class <?> found = painlessLookup .canonicalTypeNameToType (name .replace ('$' , '.' ));
99+ found = painlessLookup .canonicalTypeNameToType (name .replace ('$' , '.' ));
102100
103101 return found != null ? found : super .findClass (name );
104102 }
@@ -156,19 +154,14 @@ public Loader createLoader(ClassLoader parent) {
156154 private final Class <?> scriptClass ;
157155
158156 /**
159- * The class/interface to create the {@code scriptClass} instance.
160- */
161- private final Class <?> factoryClass ;
162-
163- /**
164- * An optional class/interface to create the {@code factoryClass} instance.
157+ * The whitelist the script will use.
165158 */
166- private final Class <?> statefulFactoryClass ;
159+ private final PainlessLookup painlessLookup ;
167160
168161 /**
169- * The whitelist the script will use .
162+ * Classes that do not exist in the lookup, but are needed by the script factories .
170163 */
171- private final PainlessLookup painlessLookup ;
164+ private final Map < String , Class <?>> additionalClasses ;
172165
173166 /**
174167 * Standard constructor.
@@ -179,9 +172,36 @@ public Loader createLoader(ClassLoader parent) {
179172 */
180173 Compiler (Class <?> scriptClass , Class <?> factoryClass , Class <?> statefulFactoryClass , PainlessLookup painlessLookup ) {
181174 this .scriptClass = scriptClass ;
182- this .factoryClass = factoryClass ;
183- this .statefulFactoryClass = statefulFactoryClass ;
184175 this .painlessLookup = painlessLookup ;
176+ Map <String , Class <?>> additionalClasses = new HashMap <>();
177+ additionalClasses .put (scriptClass .getName (), scriptClass );
178+ addFactoryMethod (additionalClasses , factoryClass , "newInstance" );
179+ addFactoryMethod (additionalClasses , statefulFactoryClass , "newFactory" );
180+ addFactoryMethod (additionalClasses , statefulFactoryClass , "newInstance" );
181+ this .additionalClasses = Collections .unmodifiableMap (additionalClasses );
182+ }
183+
184+ private static void addFactoryMethod (Map <String , Class <?>> additionalClasses , Class <?> factoryClass , String methodName ) {
185+ if (factoryClass == null ) {
186+ return ;
187+ }
188+
189+ Method factoryMethod = null ;
190+ for (Method method : factoryClass .getMethods ()) {
191+ if (methodName .equals (method .getName ())) {
192+ factoryMethod = method ;
193+ break ;
194+ }
195+ }
196+ if (factoryMethod == null ) {
197+ return ;
198+ }
199+
200+ additionalClasses .put (factoryClass .getName (), factoryClass );
201+ for (int i = 0 ; i < factoryMethod .getParameterTypes ().length ; ++i ) {
202+ Class <?> parameterClazz = factoryMethod .getParameterTypes ()[i ];
203+ additionalClasses .put (parameterClazz .getName (), parameterClazz );
204+ }
185205 }
186206
187207 /**
0 commit comments