Skip to content
Closed
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Contexts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ object Contexts {
/** Sourcefile corresponding to given abstract file, memoized */
def getSource(file: AbstractFile, codec: => Codec = Codec(settings.encoding.value)) = {
util.Stats.record("Context.getSource")
base.sources.getOrElseUpdate(file, new SourceFile(file, codec))
base.sources.getOrElseUpdate(file, SourceFile(file, codec))
}

/** SourceFile with given path name, memoized */
Expand Down
44 changes: 40 additions & 4 deletions compiler/src/dotty/tools/dotc/util/SourceFile.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,36 @@ object ScriptSourceFile {
@sharable private val headerPattern = Pattern.compile("""^(::)?!#.*(\r|\n|\r\n)""", Pattern.MULTILINE)
private val headerStarts = List("#!", "::#!")

/** Return true if has a script header */
def hasScriptHeader(content: Array[Char]): Boolean = {
headerStarts exists (content startsWith _)
}

def apply(file: AbstractFile, content: Array[Char]): SourceFile = {
/** Length of the script header from the given content, if there is one.
* The header begins with "#!" or "::#!" and ends with a line starting
* with "!#" or "::!#".
* The header begins with "#!" or "::#!" and is either a single line,
* or it ends with a line starting with "!#" or "::!#", if present.
*/
val headerLength =
if (headerStarts exists (content startsWith _)) {
// convert initial hash-bang line to a comment
val matcher = headerPattern matcher content.mkString
if (matcher.find) matcher.end
else throw new IOException("script file does not close its header with !# or ::!#")
else content.indexOf('\n') // end of first line
}
else 0
new SourceFile(file, content drop headerLength) {

// overwrite hash-bang lines with all spaces
val hashBangLines = content.take(headerLength).mkString.split("\\r?\\n")
if hashBangLines.nonEmpty then
for i <- 0 until headerLength do
content(i) match {
case '\r' | '\n' =>
case _ =>
content(i) = ' '
}

new SourceFile(file, content) {
override val underlying = new SourceFile(this.file, this.content)
}
}
Expand Down Expand Up @@ -245,6 +262,25 @@ object SourceFile {
else
sourcePath.toString
}

/** Return true if file is a script:
* if filename extension is not .scala and has a script header.
*/
def isScript(file: AbstractFile, content: Array[Char]): Boolean =
if file.hasExtension(".scala") then
false
else
ScriptSourceFile.hasScriptHeader(content)

def apply(file: AbstractFile, codec: Codec): SourceFile =
// see note above re: Files.exists is remarkably slow
val chars = try new String(file.toByteArray, codec.charSet).toCharArray
catch case _: java.nio.file.NoSuchFileException => Array[Char]()
if isScript(file, chars) then
ScriptSourceFile(file, chars)
else
new SourceFile(file, chars)

}

@sharable object NoSource extends SourceFile(NoAbstractFile, Array[Char]()) {
Expand Down
125 changes: 118 additions & 7 deletions compiler/src/dotty/tools/scripting/Main.scala
Original file line number Diff line number Diff line change
@@ -1,22 +1,133 @@
package dotty.tools.scripting

import java.io.File
import java.nio.file.Path
import java.net.URLClassLoader
import java.lang.reflect.{ Modifier, Method }

/** Main entry point to the Scripting execution engine */
object Main:
/** All arguments before -script <target_script> are compiler arguments.
All arguments afterwards are script arguments.*/
def distinguishArgs(args: Array[String]): (Array[String], File, Array[String]) =
val (compilerArgs, rest) = args.splitAt(args.indexOf("-script"))
private def distinguishArgs(args: Array[String]): (Array[String], File, Array[String], Boolean) =
// NOTE: if -script <scriptName> not present, quit with error.
val (leftArgs, rest) = args.splitAt(args.indexOf("-script"))
if( rest.size < 2 ) then
sys.error(s"missing: -script <scriptName>")

val file = File(rest(1))
val scriptArgs = rest.drop(2)
(compilerArgs, file, scriptArgs)
var saveJar = false
val compilerArgs = leftArgs.filter {
case "-save" | "-savecompiled" =>
saveJar = true
false
case _ =>
true
}
(compilerArgs, file, scriptArgs, saveJar)
end distinguishArgs

def main(args: Array[String]): Unit =
val (compilerArgs, scriptFile, scriptArgs) = distinguishArgs(args)
try ScriptingDriver(compilerArgs, scriptFile, scriptArgs).compileAndRun()
val (compilerArgs, scriptFile, scriptArgs, saveJar) = distinguishArgs(args)
try ScriptingDriver(compilerArgs, scriptFile, scriptArgs).compileAndRun { (outDir:Path, classpath:String) =>
val classFiles = outDir.toFile.listFiles.toList match {
case Nil => sys.error(s"no files below [$outDir]")
case list => list
}

val (mainClassName, mainMethod) = detectMainClassAndMethod(outDir, classpath, scriptFile)

if saveJar then
// write a standalone jar to the script parent directory
writeJarfile(outDir, scriptFile, scriptArgs, classpath, mainClassName)

// invoke the compiled script main method
mainMethod.invoke(null, scriptArgs)
}
catch
case ScriptingException(msg) =>
println(s"Error: $msg")
case e:Exception =>
e.printStackTrace
println(s"Error: ${e.getMessage}")
sys.exit(1)

case e: java.lang.reflect.InvocationTargetException =>
throw e.getCause

private def writeJarfile(outDir: Path, scriptFile: File, scriptArgs:Array[String],
classpath:String, mainClassName: String): Unit =
val jarTargetDir: Path = Option(scriptFile.toPath.getParent) match {
case None => sys.error(s"no parent directory for script file [$scriptFile]")
case Some(parent) => parent
}

def scriptBasename = scriptFile.getName.takeWhile(_!='.')
val jarPath = s"$jarTargetDir/$scriptBasename.jar"

val cpPaths = classpath.split(pathsep).map {
// protect relative paths from being converted to absolute
case str if str.startsWith(".") && File(str).isDirectory => s"${str.withSlash}/"
case str if str.startsWith(".") => str.withSlash
case str => File(str).toURI.toURL.toString
}

import java.util.jar.Attributes.Name
val cpString:String = cpPaths.distinct.mkString(" ")
val manifestAttributes:Seq[(Name, String)] = Seq(
(Name.MANIFEST_VERSION, "1.0.0"),
(Name.MAIN_CLASS, mainClassName),
(Name.CLASS_PATH, cpString),
)
import dotty.tools.io.{Jar, Directory}
val jar = new Jar(jarPath)
val writer = jar.jarWriter(manifestAttributes:_*)
writer.writeAllFrom(Directory(outDir))
end writeJarfile

private def detectMainClassAndMethod(outDir: Path, classpath: String,
scriptFile: File): (String, Method) =
val outDirURL = outDir.toUri.toURL
val classpathUrls = classpath.split(pathsep).map(File(_).toURI.toURL)
val cl = URLClassLoader(classpathUrls :+ outDirURL)

def collectMainMethods(target: File, path: String): List[(String, Method)] =
val nameWithoutExtension = target.getName.takeWhile(_ != '.')
val targetPath =
if path.nonEmpty then s"${path}.${nameWithoutExtension}"
else nameWithoutExtension

if target.isDirectory then
for
packageMember <- target.listFiles.toList
membersMainMethod <- collectMainMethods(packageMember, targetPath)
yield membersMainMethod
else if target.getName.endsWith(".class") then
val cls = cl.loadClass(targetPath)
try
val method = cls.getMethod("main", classOf[Array[String]])
if Modifier.isStatic(method.getModifiers) then List((cls.getName, method)) else Nil
catch
case _: java.lang.NoSuchMethodException => Nil
else Nil
end collectMainMethods

val candidates = for
file <- outDir.toFile.listFiles.toList
method <- collectMainMethods(file, "")
yield method

candidates match
case Nil =>
throw ScriptingException(s"No main methods detected in script ${scriptFile}")
case _ :: _ :: _ =>
throw ScriptingException("A script must contain only one main method. " +
s"Detected the following main methods:\n${candidates.mkString("\n")}")
case m :: Nil => m
end match
end detectMainClassAndMethod

def pathsep = sys.props("path.separator")

extension(pathstr:String) {
def withSlash:String = pathstr.replace('\\', '/')
}
77 changes: 26 additions & 51 deletions compiler/src/dotty/tools/scripting/ScriptingDriver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,51 @@ package dotty.tools.scripting

import java.nio.file.{ Files, Path }
import java.io.File
import java.net.{ URL, URLClassLoader }
import java.lang.reflect.{ Modifier, Method }

import scala.jdk.CollectionConverters._

import dotty.tools.dotc.{ Driver, Compiler }
import dotty.tools.dotc.core.Contexts, Contexts.{ Context, ContextBase, ctx }
import dotty.tools.dotc.config.CompilerCommand
import dotty.tools.dotc.{ Driver }
import dotty.tools.dotc.core.Contexts, Contexts.{ Context, ctx }
import dotty.tools.io.{ PlainDirectory, Directory }
import dotty.tools.dotc.reporting.Reporter
import dotty.tools.dotc.config.Settings.Setting._

import sys.process._
import dotty.tools.dotc.util.ScriptSourceFile
import dotty.tools.io.AbstractFile

class ScriptingDriver(compilerArgs: Array[String], scriptFile: File, scriptArgs: Array[String]) extends Driver:
def compileAndRun(): Unit =
def compileAndRun(pack:(Path, String) => Unit = null): Unit =
val outDir = Files.createTempDirectory("scala3-scripting")
val (toCompile, rootCtx) = setup(compilerArgs :+ scriptFile.getAbsolutePath, initCtx.fresh)

given Context = rootCtx.fresh.setSetting(rootCtx.settings.outputDir,
new PlainDirectory(Directory(outDir)))

if doCompile(newCompiler, toCompile).hasErrors then
throw ScriptingException("Errors encountered during compilation")

try detectMainMethod(outDir, ctx.settings.classpath.value).invoke(null, scriptArgs)
val result = doCompile(newCompiler, toCompile)
if result.hasErrors then
throw ScriptingException(s"Errors encountered during compilation to dir [$outDir]")

try
if outDir.toFile.listFiles.toList.isEmpty then
sys.error(s"no files generated by compiling script ${scriptFile}")

Option(pack) match {
case None =>
case Some(func) =>
val javaClasspath = sys.props("java.class.path")
val pathsep = sys.props("path.separator")
val runtimeClasspath = s"${ctx.settings.classpath.value}$pathsep$javaClasspath"
func(outDir, runtimeClasspath)
}
catch
case e: java.lang.reflect.InvocationTargetException =>
throw e.getCause
finally
deleteFile(outDir.toFile)

def content(file: Path): Array[Char] = new String(Files.readAllBytes(file)).toCharArray
def scriptSource(file: Path) = ScriptSourceFile(AbstractFile.getFile(file), content(file))

end compileAndRun

private def deleteFile(target: File): Unit =
Expand All @@ -41,46 +56,6 @@ class ScriptingDriver(compilerArgs: Array[String], scriptFile: File, scriptArgs:
target.delete()
end deleteFile

private def detectMainMethod(outDir: Path, classpath: String): Method =
val outDirURL = outDir.toUri.toURL
val classpathUrls = classpath.split(":").map(File(_).toURI.toURL)
val cl = URLClassLoader(classpathUrls :+ outDirURL)

def collectMainMethods(target: File, path: String): List[Method] =
val nameWithoutExtension = target.getName.takeWhile(_ != '.')
val targetPath =
if path.nonEmpty then s"${path}.${nameWithoutExtension}"
else nameWithoutExtension

if target.isDirectory then
for
packageMember <- target.listFiles.toList
membersMainMethod <- collectMainMethods(packageMember, targetPath)
yield membersMainMethod
else if target.getName.endsWith(".class") then
val cls = cl.loadClass(targetPath)
try
val method = cls.getMethod("main", classOf[Array[String]])
if Modifier.isStatic(method.getModifiers) then List(method) else Nil
catch
case _: java.lang.NoSuchMethodException => Nil
else Nil
end collectMainMethods

val candidates = for
file <- outDir.toFile.listFiles.toList
method <- collectMainMethods(file, "")
yield method

candidates match
case Nil =>
throw ScriptingException("No main methods detected in your script")
case _ :: _ :: _ =>
throw ScriptingException("A script must contain only one main method. " +
s"Detected the following main methods:\n${candidates.mkString("\n")}")
case m :: Nil => m
end match
end detectMainMethod
end ScriptingDriver

case class ScriptingException(msg: String) extends RuntimeException(msg)
19 changes: 19 additions & 0 deletions compiler/test-resources/scripting/hashBang.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/usr/bin/env scala
# comment
STUFF=nada
!#

def main(args: Array[String]): Unit =
System.err.printf("mainClassFromStack: %s\n",mainFromStack)
//assert(mainFromStack.contains("HashBang"),s"fromStack[$mainFromStack]")

lazy val mainFromStack:String = {
val result = new java.io.StringWriter()
new RuntimeException("stack").printStackTrace(new java.io.PrintWriter(result))
val stack = result.toString.split("[\r\n]+").toList
//for( s <- stack ){ System.err.printf("[%s]\n",s) }
stack.filter { str => str.contains(".main(") }.map {
_.replaceAll(".*[(]","").
replaceAll("[:)].*","")
}.distinct.take(1).mkString("")
}
22 changes: 22 additions & 0 deletions compiler/test-resources/scripting/mainClassOnStack.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#!/usr/bin/env scala
export STUFF=nada
lots of other stuff that isn't valid scala
!#
object Zoo {
def main(args: Array[String]): Unit =
printf("script.name: %s\n",sys.props("script.name"))
printf("mainClassFromStack: %s\n",mainFromStack)
assert(mainFromStack == "Zoo",s"fromStack[$mainFromStack]")

lazy val mainFromStack:String = {
val result = new java.io.StringWriter()
new RuntimeException("stack").printStackTrace(new java.io.PrintWriter(result))
val stack = result.toString.split("[\r\n]+").toList
// for( s <- stack ){ System.err.printf("[%s]\n",s) }
val shortStack = stack.filter { str => str.contains(".main(") && ! str.contains("$") }.map {
_.replaceAll("[.].*","").replaceAll("\\s+at\\s+","")
}
// for( s <- shortStack ){ System.err.printf("[%s]\n",s) }
shortStack.take(1).mkString("|")
}
}
6 changes: 6 additions & 0 deletions compiler/test-resources/scripting/scriptName.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/usr/bin/env scala

def main(args: Array[String]): Unit =
val name = sys.props("script.name")
printf("script.name: %s\n",name)
assert(name == "scriptName.scala")
Loading