Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

## Building strymonas

strymonas is built with SBT 1.3.8 or later and uses the [sbt-dotty](https://github.com/lampepfl/dotty/tree/master/sbt-dotty) plugin for Scala 3.
strymonas is built with SBT 1.5.0 or later for Scala 3.

* Use `sbt test` to run the tests.
* Use `sbt bench/jmh:run` to run the benchmarks
Expand Down
20 changes: 10 additions & 10 deletions bench/src/main/scala/benchmarks/BenchmarksStrymonas.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import scala.collection.mutable.ArrayBuffer
@Warmup(30)
@Fork(3)
class S {
given Toolbox = Toolbox.make(getClass.getClassLoader)
given Compiler = Compiler.make(getClass.getClassLoader)
import TestPipelines._
import S._

Expand Down Expand Up @@ -81,7 +81,7 @@ class S {

@Benchmark
def sum_staged_init_fresh_compiler(): Unit = {
given Toolbox = Toolbox.make(getClass.getClassLoader)
given Compiler = Compiler.make(getClass.getClassLoader)
run(sum)
}

Expand Down Expand Up @@ -111,7 +111,7 @@ class S {

@Benchmark
def sumOfSquares_staged_init_fresh_compiler(): Unit = {
given Toolbox = Toolbox.make(getClass.getClassLoader)
given Compiler = Compiler.make(getClass.getClassLoader)
run(sumOfSquares)
}

Expand Down Expand Up @@ -141,7 +141,7 @@ class S {

@Benchmark
def sumOfSquaresEven_staged_init_fresh_compiler(): Unit = {
given Toolbox = Toolbox.make(getClass.getClassLoader)
given Compiler = Compiler.make(getClass.getClassLoader)
run(sumOfSquaresEven)
}

Expand Down Expand Up @@ -171,7 +171,7 @@ class S {

@Benchmark
def cart_staged_init_fresh_compiler(): Unit = {
given Toolbox = Toolbox.make(getClass.getClassLoader)
given Compiler = Compiler.make(getClass.getClassLoader)
run(cart)
}

Expand Down Expand Up @@ -201,7 +201,7 @@ class S {

@Benchmark
def dotProduct_staged_init_fresh_compiler(): Unit = {
given Toolbox = Toolbox.make(getClass.getClassLoader)
given Compiler = Compiler.make(getClass.getClassLoader)
run(dotProduct)
}

Expand Down Expand Up @@ -231,7 +231,7 @@ class S {

@Benchmark
def flatMap_after_zip_staged_init_fresh_compiler(): Unit = {
given Toolbox = Toolbox.make(getClass.getClassLoader)
given Compiler = Compiler.make(getClass.getClassLoader)
run(flatMap_after_zip)
}

Expand Down Expand Up @@ -261,7 +261,7 @@ class S {

@Benchmark
def zip_after_flatMap_staged_init_fresh_compiler(): Unit = {
given Toolbox = Toolbox.make(getClass.getClassLoader)
given Compiler = Compiler.make(getClass.getClassLoader)
run(flatMap_take)
}

Expand Down Expand Up @@ -291,7 +291,7 @@ class S {

@Benchmark
def flatMap_take_staged_init_fresh_compiler(): Unit = {
given Toolbox = Toolbox.make(getClass.getClassLoader)
given Compiler = Compiler.make(getClass.getClassLoader)
run(zip_after_flatMap)
}

Expand Down Expand Up @@ -321,7 +321,7 @@ class S {

@Benchmark
def zip_flat_flat_staged_init_fresh_compiler(): Unit = {
given Toolbox = Toolbox.make(getClass.getClassLoader)
given Compiler = Compiler.make(getClass.getClassLoader)
run(zip_flat_flat)
}
}
Expand Down
27 changes: 13 additions & 14 deletions bench/src/main/scala/benchmarks/TestPipelines.scala
Original file line number Diff line number Diff line change
@@ -1,81 +1,80 @@
package benchmarks

import scala.quoted._
import scala.quoted.util._
import scala.quoted.staging._
import strymonas._

object TestPipelines {
given Toolbox = Toolbox.make(getClass.getClassLoader)
given Compiler = Compiler.make(getClass.getClassLoader)

def sum(using QuoteContext) = '{ (array: Array[Int]) =>
def sum(using Quotes) = '{ (array: Array[Int]) =>
${ Stream.of('array)
.fold('{0}, ((a, b) => '{ $a + $b })) }
}

def sumOfSquares(using QuoteContext) = '{ (array: Array[Int]) =>
def sumOfSquares(using Quotes) = '{ (array: Array[Int]) =>
${ Stream.of('{array})
.map((a) => '{ $a * $a })
.fold('{0}, ((a, b) => '{ $a + $b })) }
}

def sumOfSquaresEven(using QuoteContext) = '{ (array: Array[Int]) =>
def sumOfSquaresEven(using Quotes) = '{ (array: Array[Int]) =>
${ Stream.of('{array})
.filter((d) => '{ $d % 2 == 0 })
.map((a) => '{ $a * $a })
.fold('{0}, ((a, b) => '{ $a + $b })) }
}

def cart(using QuoteContext) = '{ (vHi: Array[Int], vLo: Array[Int]) =>
def cart(using Quotes) = '{ (vHi: Array[Int], vLo: Array[Int]) =>
${ Stream.of('{vHi})
.flatMap((d) => Stream.of('{vLo}).map((dp) => '{ $d * $dp }))
.fold('{0}, ((a, b) => '{ $a + $b })) }
}

def filter(using QuoteContext) = '{ (array: Array[Int]) =>
def filter(using Quotes) = '{ (array: Array[Int]) =>
${ Stream.of('{array})
.filter((d) => '{ $d % 2 == 0 })
.fold('{0}, ((a, b) => '{ $a + $b })) }
}

def take(using QuoteContext) = '{ (array: Array[Int]) =>
def take(using Quotes) = '{ (array: Array[Int]) =>
${ Stream.of('{array})
.take('{2})
.fold('{0}, ((a, b) => '{ $a + $b })) }
}

def flatMap_take(using QuoteContext) = '{ (array1: Array[Int], array2: Array[Int]) =>
def flatMap_take(using Quotes) = '{ (array1: Array[Int], array2: Array[Int]) =>
${ Stream.of('{array1})
.flatMap((d) => Stream.of('{array2}))
.take('{20000000})
.fold('{0}, ((a, b) => '{ $a + $b })) }
}

def dotProduct(using QuoteContext) = '{ (array1: Array[Int], array2: Array[Int]) =>
def dotProduct(using Quotes) = '{ (array1: Array[Int], array2: Array[Int]) =>
${ Stream.of('{array1})
.zip(((a: Expr[Int]) => (b: Expr[Int]) => '{ $a + $b }), Stream.of('{array2}))
.fold('{0}, ((a, b) => '{ $a + $b })) }
}

def flatMap_after_zip(using QuoteContext) = '{ (array1: Array[Int], array2: Array[Int]) =>
def flatMap_after_zip(using Quotes) = '{ (array1: Array[Int], array2: Array[Int]) =>
${ Stream.of('{array1})
.zip(((a: Expr[Int]) => (b: Expr[Int]) => '{ $a + $b }), Stream.of('{array1}))
.flatMap((d) => Stream.of('{array2}).map((dp) => '{ $d + $dp }))
.fold('{0}, ((a, b) => '{ $a + $b })) }
}

def zip_after_flatMap(using QuoteContext) = '{ (array1: Array[Int], array2: Array[Int]) =>
def zip_after_flatMap(using Quotes) = '{ (array1: Array[Int], array2: Array[Int]) =>
${ Stream.of('{array1})
.flatMap((d) => Stream.of('{array2}).map((dp) => '{ $d + $dp }))
.zip(((a: Expr[Int]) => (b: Expr[Int]) => '{ $a + $b }), Stream.of('{array1}) )
.fold('{0}, ((a, b) => '{ $a + $b })) }
}

def zip_flat_flat(using QuoteContext) = '{ (array1: Array[Int], array2: Array[Int]) =>
def zip_flat_flat(using Quotes) = '{ (array1: Array[Int], array2: Array[Int]) =>
${ Stream.of('{array1})
.flatMap((d) => Stream.of('{array2}).map((dp) => '{ $d + $dp }))
.zip(((a: Expr[Int]) => (b: Expr[Int]) => '{ $a + $b }), Stream.of('{array2}).flatMap((d) => Stream.of('{array1}).map((dp) => '{ $d + $dp })) )
.take('{20000000})
.fold('{0}, ((a, b ) => '{ $a + $b })) }
}
}
}
14 changes: 6 additions & 8 deletions build.sbt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
val dottyVersion = "0.24.0-RC1"
val dottyVersion = "3.0.3-RC1-bin-20210716-cc47c56-NIGHTLY"

lazy val root = project
.in(file("."))
Expand All @@ -8,11 +8,11 @@ lazy val root = project

scalaVersion := dottyVersion,

scalacOptions += "-language:experimental.namedTypeArguments",

libraryDependencies ++= Seq(
"ch.epfl.lamp" % "dotty_0.24" % dottyVersion,
"ch.epfl.lamp" % "dotty_0.24" % dottyVersion % "test->runtime",
"com.novocode" % "junit-interface" % "0.11" % "test",
"ch.epfl.lamp" %% "dotty-staging" % dottyVersion
scalaOrganization.value %% "scala3-staging" % dottyVersion
)
)

Expand All @@ -26,14 +26,12 @@ lazy val bench = project
scalaVersion := dottyVersion,

libraryDependencies ++= Seq(
"ch.epfl.lamp" % "dotty_0.24" % dottyVersion,
"ch.epfl.lamp" % "dotty_0.24" % dottyVersion % "test->runtime",
"ch.epfl.lamp" %% "dotty-staging" % dottyVersion
scalaOrganization.value %% "scala3-staging" % dottyVersion
),

javaOptions ++= Seq("-Xms6g", "-Xmx6g", "-Xss4m",
"-XX:+CMSClassUnloadingEnabled",
"-XX:ReservedCodeCacheSize=256m",
"-XX:-TieredCompilation", "-XX:+UseNUMA"
)
).enablePlugins(JmhPlugin)
).enablePlugins(JmhPlugin)
2 changes: 1 addition & 1 deletion project/build.properties
Original file line number Diff line number Diff line change
@@ -1 +1 @@
sbt.version=1.3.8
sbt.version=1.5.5
4 changes: 1 addition & 3 deletions project/plugins.sbt
Original file line number Diff line number Diff line change
@@ -1,3 +1 @@
addSbtPlugin("ch.epfl.lamp" % "sbt-dotty" % "0.4.1")

addSbtPlugin("pl.project13.scala" % "sbt-jmh" % "0.3.4")
addSbtPlugin("pl.project13.scala" % "sbt-jmh" % "0.3.4")
14 changes: 6 additions & 8 deletions src/main/scala/strymonas/Stream.scala
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
package strymonas

import scala.quoted._
import scala.quoted.util._
import scala.quoted.staging._
import scala.quoted.autolift
import imports._

/**
* Port of the strymonas library as described in O. Kiselyov et al., Stream fusion, to completeness (POPL 2017)
*/

type E[T] = QuoteContext ?=> Expr[T]
type E[T] = Quotes ?=> Expr[T]

case class Stream[A: Type](stream: StreamShape[Expr[A]]) extends StreamRaw {
import imports.Cardinality._
Expand Down Expand Up @@ -58,7 +56,7 @@ case class Stream[A: Type](stream: StreamShape[Expr[A]]) extends StreamRaw {
* @return a new stream consisting of all elements of the input stream that do satisfy the given
* predicate `pred`.
*/
def filter(pred: (Expr[A] => Expr[Boolean]))(using QuoteContext): Stream[A] = {
def filter(pred: (Expr[A] => Expr[Boolean]))(using Quotes): Stream[A] = {
val filterStream = (a: Expr[A]) =>
new Producer[Expr[A]] {

Expand All @@ -80,17 +78,17 @@ case class Stream[A: Type](stream: StreamShape[Expr[A]]) extends StreamRaw {


/** A stream containing the first `n` elements of this stream. */
def take(n: Expr[Int])(using QuoteContext): Stream[A] = Stream(takeRaw[Expr[A]](n, stream))
def take(n: Expr[Int])(using Quotes): Stream[A] = Stream(takeRaw[Expr[A]](n, stream))

/** zip **/
def zip[B: Type, C: Type](f: (Expr[A] => Expr[B] => Expr[C]), stream2: Stream[B])(using QuoteContext): Stream[C] = {
def zip[B: Type, C: Type](f: (Expr[A] => Expr[B] => Expr[C]), stream2: Stream[B])(using Quotes): Stream[C] = {
val Stream(stream_b) = stream2
Stream(mapRaw[(Expr[A], Expr[B]), Expr[C]]((t => k => k(f(t._1)(t._2))), zipRaw[A, Expr[B]](stream, stream_b)))
}
}

object Stream {
def of[A: Type](arr: Expr[Array[A]])(using QuoteContext): Stream[A] = {
def of[A: Type](arr: Expr[Array[A]])(using Quotes): Stream[A] = {
import imports.Cardinality._

val prod = new Producer[Expr[A]] {
Expand Down Expand Up @@ -125,4 +123,4 @@ object Stream {

Stream(Linear(prod))
}
}
}
11 changes: 5 additions & 6 deletions src/main/scala/strymonas/StreamRaw.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package strymonas

import scala.quoted._
import scala.quoted.util._
import imports._
import imports.Cardinality._

Expand Down Expand Up @@ -137,7 +136,7 @@ trait StreamRaw extends StreamRawOps {
* @tparam A the type of the producer's elements.
* @return a linear or nested stream aware of the variable reference to decrement.
*/
def takeRaw[A: Type](n: Expr[Int], stream: StreamShape[A])(using QuoteContext): StreamShape[A] = {
def takeRaw[A: Type](n: Expr[Int], stream: StreamShape[A])(using Quotes): StreamShape[A] = {
stream match {
case linear: Linear[A] => {
val enhancedProducer: Producer[(Var[Int], A)] = addCounter[A](n, linear.producer)
Expand Down Expand Up @@ -202,7 +201,7 @@ trait StreamRaw extends StreamRawOps {
}
}

def zipRaw[A: Type, B: Type](stream1: StreamShape[Expr[A]], stream2: StreamShape[B])(using QuoteContext): StreamShape[(Expr[A], B)] = {
def zipRaw[A: Type, B: Type](stream1: StreamShape[Expr[A]], stream2: StreamShape[B])(using Quotes): StreamShape[(Expr[A], B)] = {
(stream1, stream2) match {

case (Linear(producer1), Linear(producer2)) =>
Expand Down Expand Up @@ -270,7 +269,7 @@ trait StreamRaw extends StreamRawOps {
* @tparam A
* @return
*/
private def makeLinear[A: Type](stream: StreamShape[Expr[A]])(using QuoteContext): Producer[Expr[A]] = {
private def makeLinear[A: Type](stream: StreamShape[Expr[A]])(using Quotes): Producer[Expr[A]] = {
stream match {
case Linear(producer) => producer
case Nested(producer, nestedf) => {
Expand Down Expand Up @@ -374,7 +373,7 @@ trait StreamRaw extends StreamRawOps {
}
}

private def pushLinear[A, B, C](producer: Producer[A], nestedProducer: Producer[B], nestedf: (B => StreamShape[C]))(using QuoteContext): StreamShape[(A, C)] = {
private def pushLinear[A, B, C](producer: Producer[A], nestedProducer: Producer[B], nestedf: (B => StreamShape[C]))(using Quotes): StreamShape[(A, C)] = {
val newProducer = new Producer[(Var[Boolean], producer.St, B)] {

type St = (Var[Boolean], producer.St, nestedProducer.St)
Expand Down Expand Up @@ -430,4 +429,4 @@ trait StreamRaw extends StreamRawOps {
}
}
}
}
}
7 changes: 3 additions & 4 deletions src/main/scala/strymonas/StreamRawOps.scala
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
package strymonas

import scala.quoted._
import scala.quoted.util._

trait StreamRawOps {
def foldRaw[A](consumer: A => Expr[Unit], stream: StreamShape[A]): E[Unit]
def mapRaw[A, B](f: (A => (B => Expr[Unit]) => Expr[Unit]), stream: StreamShape[A]): StreamShape[B]
def flatMapRaw[A, B](f: (A => StreamShape[B]), stream: StreamShape[A]): StreamShape[B]
def takeRaw[A: Type](n: Expr[Int], stream: StreamShape[A])(using QuoteContext): StreamShape[A]
def zipRaw[A: Type, B: Type](stream1: StreamShape[Expr[A]], stream2: StreamShape[B])(using QuoteContext): StreamShape[(Expr[A], B)]
}
def takeRaw[A: Type](n: Expr[Int], stream: StreamShape[A])(using Quotes): StreamShape[A]
def zipRaw[A: Type, B: Type](stream1: StreamShape[Expr[A]], stream2: StreamShape[B])(using Quotes): StreamShape[(Expr[A], B)]
}
22 changes: 22 additions & 0 deletions src/main/scala/strymonas/Var.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package strymonas

import scala.quoted.*

sealed trait Var[T] {
def get(using qctx: Quotes): Expr[T]
def update(e: Expr[T])(using qctx: Quotes): Expr[Unit]
}

object Var {
def apply[T: Type, U: Type](init: Expr[T])(body: Var[T] => Expr[U])(using qctx: Quotes): Expr[U] = '{
var x = $init
${
body(
new Var[T] {
def get(using qctx: Quotes): Expr[T] = 'x
def update(e: Expr[T])(using qctx: Quotes): Expr[Unit] = '{ x = $e }
}
)
}
}
}
Loading