diff --git a/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala b/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala index e383bf61f36d..4bb3c9444087 100644 --- a/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala +++ b/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala @@ -325,6 +325,13 @@ private sealed trait XSettings: val XprintSuspension: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xprint-suspension", "Show when code is suspended until macros are compiled.") val Xprompt: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xprompt", "Display a prompt after each error (debugging option).") val XreplDisableDisplay: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xrepl-disable-display", "Do not display definitions in REPL.") + val XreplInterruptInstrumentation: Setting[String] = StringSetting( + AdvancedSetting, + "Xrepl-interrupt-instrumentation", + "true|false|local", + "pass `false` to disable bytecode instrumentation for interrupt handling in REPL, or `local` to limit interrupt support to only REPL-defined classes", + "true" + ) val XverifySignatures: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xverify-signatures", "Verify generic signatures in generated bytecode.") val XignoreScala2Macros: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xignore-scala2-macros", "Ignore errors when compiling code that calls Scala2 macros, these will fail at runtime.") val XimportSuggestionTimeout: Setting[Int] = IntSetting(AdvancedSetting, "Ximport-suggestion-timeout", "Timeout (in ms) for searching for import suggestions when errors are reported.", 8000) diff --git a/compiler/src/dotty/tools/dotc/quoted/Interpreter.scala b/compiler/src/dotty/tools/dotc/quoted/Interpreter.scala index 816bac14ddd2..daab56b9fff3 100644 --- a/compiler/src/dotty/tools/dotc/quoted/Interpreter.scala +++ b/compiler/src/dotty/tools/dotc/quoted/Interpreter.scala @@ -35,7 +35,7 @@ class Interpreter(pos: SrcPos, classLoader0: ClassLoader)(using Context): val classLoader = if ctx.owner.topLevelClass.name.startsWith(str.REPL_SESSION_LINE) then - new AbstractFileClassLoader(ctx.settings.outputDir.value, classLoader0) + new AbstractFileClassLoader(ctx.settings.outputDir.value, classLoader0, "false") else classLoader0 /** Local variable environment */ @@ -204,7 +204,11 @@ class Interpreter(pos: SrcPos, classLoader0: ClassLoader)(using Context): } private def loadReplLineClass(moduleClass: Symbol): Class[?] = { - val lineClassloader = new AbstractFileClassLoader(ctx.settings.outputDir.value, classLoader) + val lineClassloader = new AbstractFileClassLoader( + ctx.settings.outputDir.value, + classLoader, + "false" + ) lineClassloader.loadClass(moduleClass.name.firstPart.toString) } diff --git a/compiler/src/dotty/tools/repl/AbstractFileClassLoader.scala b/compiler/src/dotty/tools/repl/AbstractFileClassLoader.scala index 1796a7dc68b5..c46f12e1ed6d 100644 --- a/compiler/src/dotty/tools/repl/AbstractFileClassLoader.scala +++ b/compiler/src/dotty/tools/repl/AbstractFileClassLoader.scala @@ -16,11 +16,12 @@ package repl import scala.language.unsafeNulls import io.AbstractFile +import dotty.tools.repl.ReplBytecodeInstrumentation import java.net.{URL, URLConnection, URLStreamHandler} import java.util.Collections -class AbstractFileClassLoader(val root: AbstractFile, parent: ClassLoader) extends ClassLoader(parent): +class AbstractFileClassLoader(val root: AbstractFile, parent: ClassLoader, interruptInstrumentation: String) extends ClassLoader(parent): private def findAbstractFile(name: String) = root.lookupPath(name.split('/').toIndexedSeq, directory = false) // on JDK 20 the URL constructor we're using is deprecated, @@ -45,17 +46,61 @@ class AbstractFileClassLoader(val root: AbstractFile, parent: ClassLoader) exten val pathParts = name.split("[./]").toList for (dirPart <- pathParts.init) { file = file.lookupName(dirPart, true) - if (file == null) { - throw new ClassNotFoundException(name) - } + if (file == null) throw new ClassNotFoundException(name) } file = file.lookupName(pathParts.last+".class", false) - if (file == null) { - throw new ClassNotFoundException(name) - } + if (file == null) throw new ClassNotFoundException(name) + val bytes = file.toByteArray - defineClass(name, bytes, 0, bytes.length) + + if interruptInstrumentation != "false" then defineClassInstrumented(name, bytes) + else defineClass(name, bytes, 0, bytes.length) } - override def loadClass(name: String): Class[?] = try findClass(name) catch case _: ClassNotFoundException => super.loadClass(name) + def defineClassInstrumented(name: String, originalBytes: Array[Byte]) = { + val instrumentedBytes = ReplBytecodeInstrumentation.instrument(originalBytes) + defineClass(name, instrumentedBytes, 0, instrumentedBytes.length) + } + + override def loadClass(name: String): Class[?] = + if interruptInstrumentation == "false" || interruptInstrumentation == "local" + then return super.loadClass(name) + + val loaded = findLoadedClass(name) // Check if already loaded + if loaded != null then return loaded + + name match { // Don't instrument JDK classes or StopRepl + case s"java.$_" => super.loadClass(name) + case s"javax.$_" => super.loadClass(name) + case s"sun.$_" => super.loadClass(name) + case s"jdk.$_" => super.loadClass(name) + case "dotty.tools.repl.StopRepl" => + // Load StopRepl bytecode from parent but ensure each classloader gets its own copy + val classFileName = name.replace('.', '/') + ".class" + val is = Option(getParent.getResourceAsStream(classFileName)) + // Can't get as resource, use the classloader that loaded this AbstractFileClassLoader + // class itself, which must have access to StopRepl + .getOrElse(classOf[AbstractFileClassLoader].getClassLoader.getResourceAsStream(classFileName)) + + try + val bytes = is.readAllBytes() + defineClass(name, bytes, 0, bytes.length) + finally is.close() + + case _ => + try findClass(name) + catch case _: ClassNotFoundException => + // Not in REPL output, try to load from parent and instrument it + try + val resourceName = name.replace('.', '/') + ".class" + getParent.getResourceAsStream(resourceName) match { + case null => super.loadClass(resourceName) + case is => + try defineClassInstrumented(name, is.readAllBytes()) + finally is.close() + } + catch + case ex: Exception => super.loadClass(name) + } + end AbstractFileClassLoader diff --git a/compiler/src/dotty/tools/repl/Rendering.scala b/compiler/src/dotty/tools/repl/Rendering.scala index 314198a1bf97..a86ba12bb0a9 100644 --- a/compiler/src/dotty/tools/repl/Rendering.scala +++ b/compiler/src/dotty/tools/repl/Rendering.scala @@ -72,7 +72,11 @@ private[repl] class Rendering(parentClassLoader: Option[ClassLoader] = None): new java.net.URLClassLoader(compilerClasspath.toArray, baseClassLoader) } - myClassLoader = new AbstractFileClassLoader(ctx.settings.outputDir.value, parent) + myClassLoader = new AbstractFileClassLoader( + ctx.settings.outputDir.value, + parent, + ctx.settings.XreplInterruptInstrumentation.value + ) myClassLoader } diff --git a/compiler/src/dotty/tools/repl/ReplBytecodeInstrumentation.scala b/compiler/src/dotty/tools/repl/ReplBytecodeInstrumentation.scala new file mode 100644 index 000000000000..e0eae1310e11 --- /dev/null +++ b/compiler/src/dotty/tools/repl/ReplBytecodeInstrumentation.scala @@ -0,0 +1,75 @@ +package dotty.tools +package repl + +import scala.language.unsafeNulls + +import scala.tools.asm.* +import scala.tools.asm.Opcodes.* +import scala.tools.asm.tree.* +import scala.collection.JavaConverters.* +import java.util.concurrent.atomic.AtomicBoolean + +object ReplBytecodeInstrumentation: + /** Instrument bytecode to add checks to throw an exception if the REPL command is cancelled + */ + def instrument(originalBytes: Array[Byte]): Array[Byte] = + try + val cr = new ClassReader(originalBytes) + val cw = new ClassWriter(cr, ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES) + val instrumenter = new InstrumentClassVisitor(cw) + cr.accept(instrumenter, ClassReader.EXPAND_FRAMES) + cw.toByteArray + catch + case ex: Exception => originalBytes + + def setStopFlag(classLoader: ClassLoader, b: Boolean): Unit = + val cancelClassOpt = + try Some(classLoader.loadClass(classOf[dotty.tools.repl.StopRepl].getName)) + catch { + case _: java.lang.ClassNotFoundException => None + } + for(cancelClass <- cancelClassOpt) { + val setAllStopMethod = cancelClass.getDeclaredMethod("setStop", classOf[Boolean]) + setAllStopMethod.invoke(null, b.asInstanceOf[AnyRef]) + } + + private class InstrumentClassVisitor(cv: ClassVisitor) extends ClassVisitor(ASM9, cv): + + override def visitMethod( + access: Int, + name: String, + descriptor: String, + signature: String, + exceptions: Array[String] + ): MethodVisitor = + new InstrumentMethodVisitor(super.visitMethod(access, name, descriptor, signature, exceptions)) + + /** MethodVisitor that inserts stop checks at backward branches */ + private class InstrumentMethodVisitor(mv: MethodVisitor) extends MethodVisitor(ASM9, mv): + // Track labels we've seen to identify backward branches + private val seenLabels = scala.collection.mutable.Set[Label]() + + def addStopCheck() = mv.visitMethodInsn( + INVOKESTATIC, + classOf[dotty.tools.repl.StopRepl].getName.replace('.', '/'), + "throwIfReplStopped", + "()V", + false + ) + + override def visitCode(): Unit = + super.visitCode() + // Insert throwIfReplStopped() call at the start of the method + // to allow breaking out of deeply recursive methods like fib(99) + addStopCheck() + + override def visitLabel(label: Label): Unit = + seenLabels.add(label) + super.visitLabel(label) + + override def visitJumpInsn(opcode: Int, label: Label): Unit = + // Add throwIfReplStopped if this is a backward branch (jumping to a label we've already seen) + if seenLabels.contains(label) then addStopCheck() + super.visitJumpInsn(opcode, label) + +end ReplBytecodeInstrumentation diff --git a/compiler/src/dotty/tools/repl/ReplDriver.scala b/compiler/src/dotty/tools/repl/ReplDriver.scala index ffa1b648446d..5cd8fd0bf539 100644 --- a/compiler/src/dotty/tools/repl/ReplDriver.scala +++ b/compiler/src/dotty/tools/repl/ReplDriver.scala @@ -34,6 +34,7 @@ import dotty.tools.dotc.{CompilationUnit, Driver} import dotty.tools.dotc.config.CompilerCommand import dotty.tools.io.* import dotty.tools.repl.Rendering.showUser +import dotty.tools.repl.ReplBytecodeInstrumentation import dotty.tools.runner.ScalaClassLoader.* import org.jline.reader.* @@ -228,13 +229,20 @@ class ReplDriver(settings: Array[String], // Set up interrupt handler for command execution var firstCtrlCEntered = false val thread = Thread.currentThread() + + // Clear the stop flag before executing new code + ReplBytecodeInstrumentation.setStopFlag(rendering.classLoader()(using state.context), false) + val previousSignalHandler = terminal.handle( org.jline.terminal.Terminal.Signal.INT, (sig: org.jline.terminal.Terminal.Signal) => { if (!firstCtrlCEntered) { firstCtrlCEntered = true + // Set the stop flag to trigger throwIfReplStopped() in instrumented code + ReplBytecodeInstrumentation.setStopFlag(rendering.classLoader()(using state.context), true) + // Also interrupt the thread as a fallback for non-instrumented code thread.interrupt() - out.println("\nInterrupting running thread, Ctrl-C again to terminate the REPL Process") + out.println("\nAttempting to interrupt running thread with `Thread.interrupt`") } else { out.println("\nTerminating REPL Process...") System.exit(130) // Standard exit code for SIGINT @@ -592,7 +600,10 @@ class ReplDriver(settings: Array[String], val jarClassLoader = fromURLsParallelCapable( jarClassPath.asURLs, prevClassLoader) rendering.myClassLoader = new AbstractFileClassLoader( - prevOutputDir, jarClassLoader) + prevOutputDir, + jarClassLoader, + ctx.settings.XreplInterruptInstrumentation.value + ) out.println(s"Added '$path' to classpath.") } catch { diff --git a/compiler/src/dotty/tools/repl/ScriptEngine.scala b/compiler/src/dotty/tools/repl/ScriptEngine.scala index cce16000577f..490b60ed8818 100644 --- a/compiler/src/dotty/tools/repl/ScriptEngine.scala +++ b/compiler/src/dotty/tools/repl/ScriptEngine.scala @@ -24,8 +24,11 @@ class ScriptEngine extends AbstractScriptEngine { "-classpath", "", // Avoid the default "." "-usejavacp", "-color:never", - "-Xrepl-disable-display" + "-Xrepl-disable-display", + "-Xrepl-interrupt-instrumentation", + "false" ), Console.out, None) + private val rendering = new Rendering(Some(getClass.getClassLoader)) private var state: State = driver.initialState diff --git a/compiler/src/dotty/tools/repl/StopRepl.scala b/compiler/src/dotty/tools/repl/StopRepl.scala new file mode 100644 index 000000000000..17e7078a3774 --- /dev/null +++ b/compiler/src/dotty/tools/repl/StopRepl.scala @@ -0,0 +1,18 @@ +package dotty.tools.repl + +import scala.annotation.static + +class StopRepl + +object StopRepl { + // Needs to be volatile, otherwise changes to this may not get seen by other threads + // for arbitrarily long periods of time (minutes!) + @static @volatile private var stop: Boolean = false + + @static def setStop(n: Boolean): Unit = { stop = n } + + /** Check if execution should stop, and throw ThreadDeath if so */ + @static def throwIfReplStopped(): Unit = { + if (stop) throw new ThreadDeath() + } +} diff --git a/compiler/test/dotty/tools/repl/AbstractFileClassLoaderTest.scala b/compiler/test/dotty/tools/repl/AbstractFileClassLoaderTest.scala index b06c34950719..1175900864e9 100644 --- a/compiler/test/dotty/tools/repl/AbstractFileClassLoaderTest.scala +++ b/compiler/test/dotty/tools/repl/AbstractFileClassLoaderTest.scala @@ -50,13 +50,13 @@ class AbstractFileClassLoaderTest: @Test def afclGetsParent(): Unit = val p = new URLClassLoader(Array.empty[URL]) val d = new VirtualDirectory("vd", None) - val x = new AbstractFileClassLoader(d, p) + val x = new AbstractFileClassLoader(d, p, "false") assertSame(p, x.getParent) @Test def afclGetsResource(): Unit = val (fuzz, booz) = fuzzBuzzBooz booz.writeContent("hello, world") - val sut = new AbstractFileClassLoader(fuzz, NoClassLoader) + val sut = new AbstractFileClassLoader(fuzz, NoClassLoader, "false") val res = sut.getResource("buzz/booz.class") assertNotNull("Find buzz/booz.class", res) assertEquals("hello, world", slurp(res)) @@ -66,8 +66,8 @@ class AbstractFileClassLoaderTest: val (fuzz_, booz_) = fuzzBuzzBooz booz.writeContent("hello, world") booz_.writeContent("hello, world_") - val p = new AbstractFileClassLoader(fuzz, NoClassLoader) - val sut = new AbstractFileClassLoader(fuzz_, p) + val p = new AbstractFileClassLoader(fuzz, NoClassLoader, "false") + val sut = new AbstractFileClassLoader(fuzz_, p, "false") val res = sut.getResource("buzz/booz.class") assertNotNull("Find buzz/booz.class", res) assertEquals("hello, world", slurp(res)) @@ -78,7 +78,7 @@ class AbstractFileClassLoaderTest: val bass = fuzz.fileNamed("bass") booz.writeContent("hello, world") bass.writeContent("lo tone") - val sut = new AbstractFileClassLoader(fuzz, NoClassLoader) + val sut = new AbstractFileClassLoader(fuzz, NoClassLoader, "false") val res = sut.getResource("booz.class") assertNotNull(res) assertEquals("hello, world", slurp(res)) @@ -88,7 +88,7 @@ class AbstractFileClassLoaderTest: @Test def afclGetsResources(): Unit = val (fuzz, booz) = fuzzBuzzBooz booz.writeContent("hello, world") - val sut = new AbstractFileClassLoader(fuzz, NoClassLoader) + val sut = new AbstractFileClassLoader(fuzz, NoClassLoader, "false") val e = sut.getResources("buzz/booz.class") assertTrue("At least one buzz/booz.class", e.hasMoreElements) assertEquals("hello, world", slurp(e.nextElement)) @@ -99,8 +99,8 @@ class AbstractFileClassLoaderTest: val (fuzz_, booz_) = fuzzBuzzBooz booz.writeContent("hello, world") booz_.writeContent("hello, world_") - val p = new AbstractFileClassLoader(fuzz, NoClassLoader) - val x = new AbstractFileClassLoader(fuzz_, p) + val p = new AbstractFileClassLoader(fuzz, NoClassLoader, "false") + val x = new AbstractFileClassLoader(fuzz_, p, "false") val e = x.getResources("buzz/booz.class") assertTrue(e.hasMoreElements) assertEquals("hello, world", slurp(e.nextElement)) @@ -111,7 +111,7 @@ class AbstractFileClassLoaderTest: @Test def afclGetsResourceAsStream(): Unit = val (fuzz, booz) = fuzzBuzzBooz booz.writeContent("hello, world") - val x = new AbstractFileClassLoader(fuzz, NoClassLoader) + val x = new AbstractFileClassLoader(fuzz, NoClassLoader, "false") val r = x.getResourceAsStream("buzz/booz.class") assertNotNull(r) assertEquals("hello, world", closing(r)(is => Source.fromInputStream(is).mkString)) @@ -119,7 +119,7 @@ class AbstractFileClassLoaderTest: @Test def afclGetsClassBytes(): Unit = val (fuzz, booz) = fuzzBuzzBooz booz.writeContent("hello, world") - val sut = new AbstractFileClassLoader(fuzz, NoClassLoader) + val sut = new AbstractFileClassLoader(fuzz, NoClassLoader, "false") val b = sut.classBytes("buzz/booz.class") assertEquals("hello, world", new String(b, UTF8.charSet)) @@ -129,8 +129,8 @@ class AbstractFileClassLoaderTest: booz.writeContent("hello, world") booz_.writeContent("hello, world_") - val p = new AbstractFileClassLoader(fuzz, NoClassLoader) - val sut = new AbstractFileClassLoader(fuzz_, p) + val p = new AbstractFileClassLoader(fuzz, NoClassLoader, "false") + val sut = new AbstractFileClassLoader(fuzz_, p, "false") val b = sut.classBytes("buzz/booz.class") assertEquals("hello, world", new String(b, UTF8.charSet)) end AbstractFileClassLoaderTest diff --git a/staging/src/scala/quoted/staging/QuoteDriver.scala b/staging/src/scala/quoted/staging/QuoteDriver.scala index 82e91f7d7888..813d8d8c16be 100644 --- a/staging/src/scala/quoted/staging/QuoteDriver.scala +++ b/staging/src/scala/quoted/staging/QuoteDriver.scala @@ -61,7 +61,7 @@ private class QuoteDriver(appClassloader: ClassLoader) extends Driver: case Left(classname) => assert(!ctx.reporter.hasErrors) - val classLoader = new AbstractFileClassLoader(outDir, appClassloader) + val classLoader = new AbstractFileClassLoader(outDir, appClassloader, "false") val clazz = classLoader.loadClass(classname) val method = clazz.getMethod("apply")